Spaces · x-lai
Commit 3d9fba4
1 Parent(s): 11d7ed8

Release training script
Former-commit-id: 6f951959fdf50617a5ad55be75bb9139e63fa04b

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +4 -1
- README.md +79 -2
- chat.py +214 -144
- model/LISA.py +471 -194
- model/llava/conversation.py +50 -28
- model/llava/eval/eval_gpt_review.py +57 -43
- model/llava/eval/eval_gpt_review_visual.py +67 -48
- model/llava/eval/eval_science_qa.py +44 -40
- model/llava/eval/eval_science_qa_gpt4.py +36 -32
- model/llava/eval/eval_science_qa_gpt4_requery.py +85 -70
- model/llava/eval/generate_webpage_data_from_table.py +46 -38
- model/llava/eval/model_qa.py +30 -17
- model/llava/eval/model_vqa.py +113 -50
- model/llava/eval/model_vqa_science.py +136 -69
- model/llava/eval/qa_baseline_gpt35.py +35 -27
- model/llava/eval/run_llava.py +78 -33
- model/llava/eval/run_llava_batch.py +240 -69
- model/llava/eval/run_llava_batch_v2.py +240 -69
- model/llava/eval/run_llava_batch_v3.py +244 -70
- model/llava/eval/summarize_gpt_review.py +12 -10
- model/llava/model/__init__.py +2 -2
- model/llava/model/apply_delta.py +16 -8
- model/llava/model/consolidate.py +4 -2
- model/llava/model/llava.py +211 -65
- model/llava/model/llava_mpt.py +284 -76
- model/llava/model/make_delta.py +19 -8
- model/llava/model/mpt/adapt_tokenizer.py +11 -6
- model/llava/model/mpt/attention.py +289 -70
- model/llava/model/mpt/blocks.py +60 -11
- model/llava/model/mpt/configuration_mpt.py +104 -28
- model/llava/model/mpt/hf_prefixlm_converter.py +439 -104
- model/llava/model/mpt/meta_init_context.py +27 -10
- model/llava/model/mpt/modeling_mpt.py +278 -84
- model/llava/model/mpt/norm.py +67 -17
- model/llava/model/mpt/param_init_fns.py +290 -52
- model/llava/model/utils.py +21 -10
- model/llava/serve/cli.py +25 -20
- model/llava/serve/controller.py +35 -25
- model/llava/serve/gradio_css.py +3 -5
- model/llava/serve/gradio_patch.py +6 -7
- model/llava/serve/gradio_web_server.py +208 -92
- model/llava/serve/model_worker.py +170 -82
- model/llava/serve/test_message.py +18 -9
- model/llava/train/llama_flash_attn_monkey_patch.py +53 -41
- model/llava/train/llava_trainer.py +13 -9
- model/llava/train/train.py +211 -138
- model/llava/train/train_mem.py +2 -1
- model/llava/utils.py +19 -11
- model/segment_anything/__init__.py +3 -8
- model/segment_anything/automatic_mask_generator.py +26 -26
.gitignore CHANGED

@@ -1 +1,4 @@
-**/__pycache__
+**/__pycache__
+runs/
+.vscode/
+
README.md CHANGED

@@ -47,6 +47,83 @@ For more details, please refer to the [paper](https://arxiv.org/abs/2308.00692).
 ```
 pip install -r requirements.txt
 ```
+
+## Training
+### Training Data Preparation
+The training data consists of 4 types of data:
+
+1. Semantic segmentation datasets: [ADE20K](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip), [COCO-Stuff](https://github.com/nightrome/cocostuff#downloads), [Mapillary](https://www.mapillary.com/dataset/vistas), [PACO-LVIS](https://github.com/facebookresearch/paco/tree/main#dataset-setup), [PASCAL-Part](http://roozbehm.info/pascal-parts/pascal-parts.html)
+
+2. Referring segmentation datasets: refCOCO, refCOCO+, refCOCOg [\[Download\]](https://github.com/lichengunc/refer#download)
+
+3. Visual Question Answering dataset: [LLaVA-Instruct-150k](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_instruct_150k.json)
+
+4. Reasoning segmentation dataset: [ReasonSeg](https://github.com/dvlab-research/LISA#dataset)
+
+Download them from the above links, and organize them as follows.
+
+```
+├── dataset
+│   ├── ade20k
+│   │   ├── annotations
+│   │   └── images
+│   ├── coco
+│   │   └── train2017
+│   ├── cocostuff
+│   │   ├── annotations
+│   │   └── train2017
+│   ├── llava_dataset
+│   │   └── llava_instruct_150k.json
+│   ├── mapillary
+│   │   ├── config_v2.0.json
+│   │   ├── testing
+│   │   ├── training
+│   │   └── validation
+│   ├── reason_seg
+│   │   └── ReasonSeg
+│   │       ├── train
+│   │       ├── val
+│   │       └── explanatory
+│   ├── refer_seg
+│   │   ├── images
+│   │   |   ├── saiapr_tc-12
+│   │   |   └── mscoco
+│   │   |       └── images
+│   │   |           └── train2014
+│   │   ├── refclef
+│   │   ├── refcoco
+│   │   ├── refcoco+
+│   │   └── refcocog
+│   └── vlpart
+│       ├── paco
+│       │   └── annotations
+│       └── pascal_part
+│           ├── train.json
+│           └── VOCdevkit
+```
+
+### Pre-trained weights
+
+#### LLaVA
+To train LISA-7B or 13B, you need to follow the [instruction](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md) to merge the LLaVA delta weights. Typically, we use the final weights `LLaVA-Lightning-7B-v1-1` and `LLaVA-13B-v1-1` merged from `liuhaotian/LLaVA-Lightning-7B-delta-v1-1` and `liuhaotian/LLaVA-13b-delta-v1-1`, respectively. For Llama2, we can directly use the LLaVA full weights `liuhaotian/llava-llama-2-13b-chat-lightning-preview`.
+
+#### SAM ViT-H weights
+Download SAM ViT-H pre-trained weights from the [link](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth).
+
+### Training
+```
+deepspeed --master_port=24999 train_ds.py --version="PATH_TO_LLaVA_Weights" --dataset_dir='./dataset' --vision_pretrained="PATH_TO_SAM_Weights" --exp_name="lisa-7b"
+```
+When training is finished, to get the full model weight:
+```
+cd ./runs/lisa-7b/ckpt_model && python zero_to_fp32.py . ../pytorch_model.bin
+```
+
+### Validation
+```
+deepspeed --master_port=24999 train_ds.py --version="PATH_TO_LLaVA_Weights" --dataset_dir='./dataset' --vision_pretrained="PATH_TO_SAM_Weights" --exp_name="lisa-7b" --weight='PATH_TO_pytorch_model.bin' --eval_only
+```
+
 ## Inference
 To chat with [LISA-13B-llama2-v0](https://huggingface.co/xinlai/LISA-13B-llama2-v0) or [LISA-13B-llama2-v0-explainatory](https://huggingface.co/xinlai/LISA-13B-llama2-v0-explainatory): (Note that LISA-13B-llama2-v0 currently does not support explanatory answers.)

@@ -93,9 +170,9 @@ Important keys contained in JSON files:
 
 The elements of "shapes" fall into two categories, namely **"target"** and **"ignore"**. The former category is indispensable for evaluation, while the latter denotes ambiguous regions and is hence disregarded during the evaluation process.
 
-We provide a <a href="https://github.com/dvlab-research/LISA/blob/main/utils/
+We provide a <a href="https://github.com/dvlab-research/LISA/blob/main/utils/data_processing.py">**script**</a> that demonstrates how to process the annotations:
 ```
-python3 utils/
+python3 utils/data_processing.py
 ```
 
 Besides, we leveraged GPT-3.5 for rephrasing instructions, so images in the training set may have **more than one instruction (but fewer than six)** in the "text" field. During training, users may randomly select one as the text query to obtain a better model.
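The annotation layout described above (shapes split into "target" and "ignore", plus a "text" list of rephrased instructions) maps naturally onto a small loader. The sketch below is illustrative only and is not part of this commit; the exact key used to distinguish "target" from "ignore" shapes is an assumption (here `"label"`), and `utils/data_processing.py` in the repository is the authoritative reference.

```
import json
import random


def load_reason_seg_sample(json_path):
    """Minimal sketch of reading one ReasonSeg-style annotation.

    Assumes each shape entry carries a "label" field that is either
    "target" or "ignore", as described in the README; treat the field
    name as a hypothetical placeholder, not the confirmed schema.
    """
    with open(json_path, "r") as f:
        ann = json.load(f)

    # "ignore" shapes mark ambiguous regions and are skipped for evaluation.
    target_shapes = [s for s in ann["shapes"] if s.get("label") == "target"]

    # Images may carry several rephrased instructions (fewer than six);
    # sample one per training step as the text query.
    text_query = random.choice(ann["text"])
    return target_shapes, text_query
```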
chat.py CHANGED

@@ -1,38 +1,48 @@
import argparse
import os
import sys

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from transformers import AutoTokenizer, CLIPImageProcessor

from model.LISA import LISA
from model.segment_anything.utils.transforms import ResizeLongestSide
from utils.conversation import get_default_conv_template


def parse_args(args):
    parser = argparse.ArgumentParser(description="LISA chat")
    parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v0")
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image-size", default=1024, type=int, help="image size")
    parser.add_argument("--model-max-length", default=512, type=int)
    parser.add_argument("--lora-r", default=-1, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    return parser.parse_args(args)


def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Normalize pixel values and pad to a square input."""
    # Normalize colors
    x = (x - pixel_mean) / pixel_std

@@ -45,125 +55,185 @@ def preprocess(x,
def main(args):
    args = parse_args(args)
    os.makedirs(args.vis_save_path, exist_ok=True)

    # Create model
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    num_added_tokens = tokenizer.add_tokens("[SEG]")
    ret_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids
    args.seg_token_idx = ret_token_idx[0]

    model = LISA(
        args.local_rank,
        args.seg_token_idx,
        tokenizer,
        args.version,
        args.lora_r,
        args.precision,
        load_in_8bit=args.load_in_8bit,
        load_in_4bit=args.load_in_4bit,
    )

    weight = {}
    visual_model_weight = torch.load(
        os.path.join(args.version, "pytorch_model-visual_model.bin")
    )
    text_hidden_fcs_weight = torch.load(
        os.path.join(args.version, "pytorch_model-text_hidden_fcs.bin")
    )
    weight.update(visual_model_weight)
    weight.update(text_hidden_fcs_weight)
    missing_keys, unexpected_keys = model.load_state_dict(weight, strict=False)

    if args.precision == "bf16":
        model = model.bfloat16().cuda()
    elif args.precision == "fp16":
        import deepspeed

        model_engine = deepspeed.init_inference(
            model=model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        model = model_engine.module
    else:
        model = model.float().cuda()

    DEFAULT_IMAGE_TOKEN = "<image>"
    DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
    DEFAULT_IM_START_TOKEN = "<im_start>"
    DEFAULT_IM_END_TOKEN = "<im_end>"
    image_token_len = 256

    clip_image_processor = CLIPImageProcessor.from_pretrained(args.vision_tower)
    transform = ResizeLongestSide(args.image_size)

    while True:
        conv = get_default_conv_template("vicuna").copy()
        conv.messages = []

        prompt = input("Please input your prompt: ")
        prompt = DEFAULT_IMAGE_TOKEN + " " + prompt
        replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
        replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
        prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()

        image_path = input("Please input the image path: ")
        if not os.path.exists(image_path):
            print("File not found in {}".format(image_path))
            continue

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        original_size_list = [image.shape[:2]]
        if args.precision == "bf16":
            images_clip = (
                clip_image_processor.preprocess(image, return_tensors="pt")[
                    "pixel_values"
                ][0]
                .unsqueeze(0)
                .cuda()
                .bfloat16()
            )
        elif args.precision == "fp16":
            images_clip = (
                clip_image_processor.preprocess(image, return_tensors="pt")[
                    "pixel_values"
                ][0]
                .unsqueeze(0)
                .cuda()
                .half()
            )
        else:
            images_clip = (
                clip_image_processor.preprocess(image, return_tensors="pt")[
                    "pixel_values"
                ][0]
                .unsqueeze(0)
                .cuda()
                .float()
            )
        images = transform.apply_image(image)
        resize_list = [images.shape[:2]]
        if args.precision == "bf16":
            images = (
                preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
                .unsqueeze(0)
                .cuda()
                .bfloat16()
            )
        elif args.precision == "fp16":
            images = (
                preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
                .unsqueeze(0)
                .cuda()
                .half()
            )
        else:
            images = (
                preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
                .unsqueeze(0)
                .cuda()
                .float()
            )

        input_ids = tokenizer(prompt).input_ids
        input_ids = torch.LongTensor(input_ids).unsqueeze(0).cuda()
        output_ids, pred_masks = model.evaluate(
            images_clip,
            images,
            input_ids,
            resize_list,
            original_size_list,
            max_new_tokens=512,
            tokenizer=tokenizer,
        )
        text_output = tokenizer.decode(output_ids[0], skip_special_tokens=False)
        text_output = (
            text_output.replace(DEFAULT_IMAGE_PATCH_TOKEN, "")
            .replace("\n", "")
            .replace(" ", "")
        )

        print("text_output: ", text_output)
        for i, pred_mask in enumerate(pred_masks):
            if pred_mask.shape[0] == 0:
                continue

            pred_mask = pred_mask.detach().cpu().numpy()[0]
            pred_mask = pred_mask > 0

            save_path = "{}/{}_mask_{}.jpg".format(
                args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
            )
            cv2.imwrite(save_path, pred_mask * 100)
            print("{} has been saved.".format(save_path))

            save_path = "{}/{}_masked_img_{}.jpg".format(
                args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
            )
            save_img = image.copy()
            save_img[pred_mask] = (
                image * 0.5
                + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
            )[pred_mask]
            save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
            cv2.imwrite(save_path, save_img)
            print("{} has been saved.".format(save_path))


if __name__ == "__main__":
    main(sys.argv[1:])
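The body of `preprocess()` between the two hunks (the padding step after normalization) falls outside the diff context and is not shown. As a point of reference only, a minimal sketch of what such a pad-to-square step typically looks like is given below, assuming the SAM-style convention of zero-padding on the bottom and right up to `img_size`; it is not the exact code omitted from the diff.

```
import torch
import torch.nn.functional as F


def pad_to_square(x: torch.Tensor, img_size: int = 1024) -> torch.Tensor:
    """Sketch of the padding step that follows normalization in preprocess():
    zero-pad a (C, H, W) tensor on the bottom/right to img_size x img_size.
    This mirrors the usual SAM preprocessing convention (an assumption here)."""
    h, w = x.shape[-2:]
    pad_h = img_size - h
    pad_w = img_size - w
    # F.pad pads the last two dims in the order (left, right, top, bottom).
    return F.pad(x, (0, pad_w, 0, pad_h))
```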
model/LISA.py CHANGED

@@ -1,213 +1,490 @@
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model
from transformers import BitsAndBytesConfig, CLIPVisionModel

from .llava.model.llava import LlavaLlamaForCausalLM
from .segment_anything import build_sam_vit_h
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
                         DEFAULT_IMAGE_PATCH_TOKEN)


def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    scale=1000,  # 100000.0,
    eps=1e-6,
):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1, 2)
    targets = targets.flatten(1, 2)
    numerator = 2 * (inputs / scale * targets).sum(-1)
    denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1)
    loss = 1 - (numerator + eps) / (denominator + eps)
    loss = loss.sum() / (num_masks + 1e-8)
    return loss


def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    Returns:
        Loss tensor
    """
    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    loss = loss.flatten(1, 2).mean(1).sum() / (num_masks + 1e-8)
    return loss


class LISA(nn.Module):
    def __init__(
        self,
        local_rank,
        seg_token_idx,
        tokenizer,
        llm_version,
        lora_r,
        precision,
        load_in_4bit=False,
        load_in_8bit=False,
        lora_target_modules=["q_proj", "v_proj"],
        lora_alpha=16,
        lora_dropout=0.05,
        vision_tower="openai/clip-vit-large-patch14",
        mm_vision_select_layer=-2,
        freeze_lm=True,
        train_mask_decoder=True,
        out_dim=256,
        ce_loss_weight=1.0,
        dice_loss_weight=0.5,
        bce_loss_weight=2.0,
        vision_pretrained=None,
    ):
        super().__init__()
        self.local_rank = local_rank
        self.tokenizer = tokenizer
        self.image_token = tokenizer.cls_token_id
        self.precision = precision
        self.ce_loss_weight = ce_loss_weight
        self.dice_loss_weight = dice_loss_weight
        self.bce_loss_weight = bce_loss_weight

        # LLaVA
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        num_new_tokens = tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )
        if precision == "bf16":
            self.lm = LlavaLlamaForCausalLM.from_pretrained(
                llm_version,
                torch_dtype=torch.bfloat16,
                cache_dir=None,
                low_cpu_mem_usage=True,
            )
        elif precision == "fp16":
            if load_in_4bit:
                self.lm = LlavaLlamaForCausalLM.from_pretrained(
                    llm_version,
                    load_in_4bit=True,
                    cache_dir=None,
                    low_cpu_mem_usage=True,
                    device_map="auto",
                    quantization_config=BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.float16,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4",
                    ),
                )
            elif load_in_8bit:
                self.lm = LlavaLlamaForCausalLM.from_pretrained(
                    llm_version,
                    load_in_8bit=True,
                    cache_dir=None,
                    low_cpu_mem_usage=True,
                    device_map="auto",
                )
            else:
                self.lm = LlavaLlamaForCausalLM.from_pretrained(
                    llm_version,
                    torch_dtype=torch.half,
                    cache_dir=None,
                    low_cpu_mem_usage=True,
                )
        else:
            self.lm = LlavaLlamaForCausalLM.from_pretrained(
                llm_version,
                torch_dtype=torch.float32,
                cache_dir=None,
                low_cpu_mem_usage=True,
            )

        self.lm.enable_input_require_grads()
        self.lm.gradient_checkpointing_enable()
        self.lm.config.use_cache = False
        model_vision_dict = self.lm.get_model().initialize_vision_modules(
            vision_tower=vision_tower,
            mm_vision_select_layer=mm_vision_select_layer,
            precision=precision,
        )
        vision_config = model_vision_dict["vision_config"]
        vision_tower = self.lm.get_model().vision_tower[0]
        self.lm.model.config.eos_token_id = tokenizer.eos_token_id
        self.lm.model.config.bos_token_id = tokenizer.bos_token_id
        self.lm.model.config.pad_token_id = tokenizer.pad_token_id

        if vision_tower.device.type == "meta":
            if precision == "bf16":
                vision_tower = CLIPVisionModel.from_pretrained(
                    vision_tower.config._name_or_path,
                    torch_dtype=torch.bfloat16,
                    low_cpu_mem_usage=True,
                ).cuda(local_rank)
            elif precision == "fp16":
                vision_tower = CLIPVisionModel.from_pretrained(
                    vision_tower.config._name_or_path,
                    torch_dtype=torch.half,
                    low_cpu_mem_usage=True,
                ).cuda(local_rank)
            else:
                vision_tower = CLIPVisionModel.from_pretrained(
                    vision_tower.config._name_or_path,
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True,
                ).cuda(local_rank)
            self.lm.get_model().vision_tower[0] = vision_tower
        else:
            if precision == "bf16":
                vision_tower.to(device="cuda", dtype=torch.bfloat16)
            elif precision == "fp16":
                vision_tower.to(device="cuda", dtype=torch.half)
            else:
                vision_tower.to(device="cuda", dtype=torch.float32)

        self.lm.config.tune_mm_mlp_adapter = False
        self.lm.config.freeze_mm_mlp_adapter = False
        self.lm.config.mm_use_im_start_end = True
        vision_config.use_im_start_end = True
        self.lm.config.sep_image_conv_front = False

        self.lm.initialize_vision_tokenizer(
            mm_use_im_start_end=True,
            tokenizer=tokenizer,
            num_new_tokens=num_new_tokens,
            device=local_rank,
            tune_mm_mlp_adapter=False,
        )
        if freeze_lm:
            for n, param in self.lm.named_parameters():
                param.requires_grad = False

        # LoRA
        if lora_r > 0:
            config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                target_modules=lora_target_modules,
                lora_dropout=lora_dropout,
                bias="none",
                task_type="CAUSAL_LM",
            )
            self.lm = get_peft_model(self.lm, config)
            self.lm.print_trainable_parameters()

        self.llm_version = llm_version

        self.seg_token_idx = seg_token_idx
        self.lm.resize_token_embeddings(len(tokenizer))

        for n, p in self.lm.named_parameters():
            if any([x in n for x in ["lm_head", "embed_tokens"]]) and p.shape[0] == len(tokenizer):
                p.requires_grad = True

        # SAM
        self.visual_model = build_sam_vit_h(vision_pretrained)
        for param in self.visual_model.parameters():
            param.requires_grad = False
        if train_mask_decoder:
            self.visual_model.mask_decoder.train()
            for param in self.visual_model.mask_decoder.parameters():
                param.requires_grad = True

        # Projection layer
        in_dim = self.lm.config.hidden_size
        text_fc = [
            nn.Linear(in_dim, in_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim),
            nn.Dropout(0.0),
        ]
        self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])

    def get_visual_embs(self, pixel_values: torch.FloatTensor):
        with torch.no_grad():
            image_embeddings = self.visual_model.image_encoder(pixel_values)
        return image_embeddings

    def forward(
        self,
        images: torch.FloatTensor,
        images_clip: torch.FloatTensor,
        input_ids: torch.LongTensor,
        labels: torch.LongTensor,
        attention_masks: torch.LongTensor,
        offset: torch.LongTensor,
        masks_list: List[torch.FloatTensor],
        label_list: List[torch.Tensor],
        resize_list: List[tuple],
        inference: bool = False,
        **kwargs,
    ):
        image_embeddings = self.get_visual_embs(images)
        batch_size = image_embeddings.shape[0]
        assert batch_size == len(offset) - 1

        seg_token_mask = input_ids[:, 1:] == self.seg_token_idx
        seg_token_mask = torch.cat(
            [
                seg_token_mask,
                torch.zeros((seg_token_mask.shape[0], 1)).bool().cuda(self.local_rank),
            ],
            dim=1,
        )

        if inference:
            n_batch = 1
            length = input_ids.shape[0]
            assert images_clip.shape[0] == 1
            images_clip_extend = images_clip.expand(length, -1, -1, -1).contiguous()

            output_hidden_states = []
            for i in range(n_batch):
                start_i, end_i = i * length, min((i + 1) * length, input_ids.shape[0])
                output_i = self.lm(
                    images=images_clip_extend[: end_i - start_i],
                    attention_mask=attention_masks[start_i:end_i],
                    input_ids=input_ids[start_i:end_i],
                    output_hidden_states=True,
                )
                output_hidden_states.append(output_i.hidden_states)
                torch.cuda.empty_cache()

            output_hidden_states_list = []
            output_hidden_states_level = torch.cat(output_hidden_states, dim=0)
            output_hidden_states_list.append(output_hidden_states_level)
            output_hidden_states = output_hidden_states_list
            output = None

        else:
            images_clip_list = []
            for i in range(len(offset) - 1):
                start_i, end_i = offset[i], offset[i + 1]
                images_clip_i = (
                    images_clip[i]
                    .unsqueeze(0)
                    .expand(end_i - start_i, -1, -1, -1)
                    .contiguous()
                )
                images_clip_list.append(images_clip_i)
            images_clip = torch.cat(images_clip_list, dim=0)

            output = self.lm(
                images=images_clip,
                attention_mask=attention_masks,
                input_ids=input_ids,
                labels=labels,
                output_hidden_states=True,
            )
            output_hidden_states = output.hidden_states

        hidden_states = []

        assert len(self.text_hidden_fcs) == 1
        hidden_states.append(self.text_hidden_fcs[0](output_hidden_states[-1]))

        last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)

        pred_embeddings = last_hidden_state[seg_token_mask]
        seg_token_counts = seg_token_mask.int().sum(-1)  # [bs, ]

        seg_token_offset = seg_token_counts.cumsum(-1)
        seg_token_offset = torch.cat(
            [torch.zeros(1).long().cuda(), seg_token_offset], dim=0
        )

        seg_token_offset = seg_token_offset[offset]

        pred_embeddings_ = []
        for i in range(len(seg_token_offset) - 1):
            start_i, end_i = seg_token_offset[i], seg_token_offset[i + 1]
            pred_embeddings_.append(pred_embeddings[start_i:end_i])
        pred_embeddings = pred_embeddings_

        multimask_output = False
        pred_masks = []
        for i in range(len(pred_embeddings)):
            sparse_embeddings, dense_embeddings = self.visual_model.prompt_encoder(
                points=None,
                boxes=None,
                masks=None,
                text_embeds=pred_embeddings[i].unsqueeze(1),
            )
            sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
            low_res_masks, iou_predictions = self.visual_model.mask_decoder(
                image_embeddings=image_embeddings[i].unsqueeze(0),
                image_pe=self.visual_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            pred_mask = self.visual_model.postprocess_masks(
                low_res_masks,
                input_size=resize_list[i],
                original_size=label_list[i].shape,
            )
            pred_masks.append(pred_mask[:, 0])

        model_output = output
        gt_masks = masks_list

        if inference:
            return {
                "pred_masks": pred_masks,
                "gt_masks": gt_masks,
            }

        output = model_output.logits

        ce_loss = model_output.loss
        ce_loss = ce_loss * self.ce_loss_weight
        loss = ce_loss
        mask_bce_loss = 0
        mask_dice_loss = 0
        num_masks = 0
        for batch_idx in range(len(pred_masks)):
            gt_mask = gt_masks[batch_idx]
            pred_mask = pred_masks[batch_idx]

            assert (
                gt_mask.shape[0] == pred_mask.shape[0]
            ), "gt_mask.shape: {}, pred_mask.shape: {}".format(
                gt_mask.shape, pred_mask.shape
            )
            mask_bce_loss += (
                sigmoid_ce_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
                * gt_mask.shape[0]
            )
            mask_dice_loss += (
                dice_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
                * gt_mask.shape[0]
            )
            num_masks += gt_mask.shape[0]

        mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
        mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
        mask_loss = mask_bce_loss + mask_dice_loss

        loss += mask_loss

        return {
            "loss": loss,
            "ce_loss": ce_loss,
            "mask_bce_loss": mask_bce_loss,
            "mask_dice_loss": mask_dice_loss,
            "mask_loss": mask_loss,
        }

    def evaluate(
        self,
        images_clip,
        images,
        input_ids,
        resize_list,
        original_size_list,
        max_new_tokens=32,
        tokenizer=None,
    ):
        with torch.no_grad():
            outputs = self.lm.generate(
                images=images_clip,
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                num_beams=1,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            output_hidden_states = outputs.hidden_states[-1]
            output_ids = outputs.sequences

            seg_token_mask = output_ids[:, 1:] == self.seg_token_idx

            hidden_states = []

            assert len(self.text_hidden_fcs) == 1
            hidden_states.append(self.text_hidden_fcs[0](output_hidden_states))

            last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
            pred_embeddings = last_hidden_state[seg_token_mask]

            seg_token_counts = seg_token_mask.int().sum(-1)  # [bs, ]
            seg_token_offset = seg_token_counts.cumsum(-1)
            seg_token_offset = torch.cat(
                [torch.zeros(1).long().cuda(), seg_token_offset], dim=0
            )

            pred_embeddings_ = []
            for i in range(len(seg_token_offset) - 1):
                start_i, end_i = seg_token_offset[i], seg_token_offset[i + 1]
                pred_embeddings_.append(pred_embeddings[start_i:end_i])
            pred_embeddings = pred_embeddings_

            image_embeddings = self.get_visual_embs(images)

            multimask_output = False
            pred_masks = []
            for i in range(len(pred_embeddings)):
                sparse_embeddings, dense_embeddings = self.visual_model.prompt_encoder(
                    points=None,
                    boxes=None,
                    masks=None,
                    text_embeds=pred_embeddings[i].unsqueeze(1),
                )

                sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
                low_res_masks, iou_predictions = self.visual_model.mask_decoder(
                    image_embeddings=image_embeddings[i].unsqueeze(0),
                    image_pe=self.visual_model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=multimask_output,
                )

                pred_mask = self.visual_model.postprocess_masks(
                    low_res_masks,
                    input_size=resize_list[i],
                    original_size=original_size_list[i],
                )
                pred_masks.append(pred_mask[:, 0])

        return output_ids, pred_masks
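As a quick sanity check on the two mask losses above (a sketch, not part of the commit): both `dice_loss` and `sigmoid_ce_loss` expect raw logits and binary targets with a leading mask dimension, e.g. shape `(num_masks, H, W)`, and normalize by `num_masks`, so their combination can be exercised on toy tensors with the same weights `LISA.forward` applies for a single batch element.

```
import torch

from model.LISA import dice_loss, sigmoid_ce_loss

# Toy inputs: two predicted mask logit maps and binary ground-truth masks.
pred = torch.randn(2, 64, 64)
gt = (torch.rand(2, 64, 64) > 0.5).float()

bce = sigmoid_ce_loss(pred, gt, num_masks=pred.shape[0])
dice = dice_loss(pred, gt, num_masks=pred.shape[0])

# Same default weights as LISA: bce_loss_weight=2.0, dice_loss_weight=0.5.
mask_loss = 2.0 * bce + 0.5 * dice
print(mask_loss.item())
```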
model/llava/conversation.py CHANGED

@@ -1,10 +1,11 @@
import dataclasses
from enum import Enum, auto
from typing import List, Tuple


class SeparatorStyle(Enum):
    """Different separator style."""

    SINGLE = auto()
    TWO = auto()
    MPT = auto()

@@ -13,6 +14,7 @@ class SeparatorStyle(Enum):
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""

    system: str
    roles: List[str]
    messages: List[List[str]]

@@ -64,33 +66,43 @@ class Conversation:
    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO

                    from PIL import Image

                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":

                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(
                                    pil_img.mode, (width, width), background_color
                                )
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(
                                    pil_img.mode, (height, height), background_color
                                )
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result

                        image = expand2square(image)
                    elif image_process_mode == "Crop":
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((224, 224))
                    else:
                        raise ValueError(
                            f"Invalid image_process_mode: {image_process_mode}"
                        )
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400

@@ -113,11 +125,12 @@ class Conversation:
    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO

                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw

@@ -135,7 +148,7 @@ class Conversation:
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = msg.replace("<image>", img_str)
                    ret.append([msg, None])
                else:
                    ret[-1][-1] = msg

@@ -149,14 +162,17 @@ class Conversation:
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
        )

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [
                    [x, y[0] if type(y) is tuple else y] for x, y in self.messages
                ],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,

@@ -173,11 +189,12 @@
conv_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Give three tips for staying healthy."),
        (
            "Assistant",
            "Sure, here are three tips for staying healthy:\n"
            "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
            "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "

@@ -191,7 +208,8 @@ conv_v1 = Conversation(
            "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
            "and mental health. Adults should aim for seven to nine hours of sleep per night. "
            "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
            "help improve the quality of your sleep.",
        ),
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,

@@ -200,11 +218,15 @@
conv_v1_2 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        (
            "Human",
            "What are the key differences between renewable and non-renewable energy sources?",
        ),
        (
            "Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "

@@ -222,7 +244,8 @@ conv_v1_2 = Conversation(
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
        ),
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,

@@ -280,12 +303,12 @@ conv_bair_v1 = Conversation(
simple_conv = Conversation(
    system="You are LLaVA, a large language model trained by UW Madison WAIV Lab, based on LLaMA architecture."
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Hi!"),
        ("Assistant", "Hi there! How can I help you today?\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,

@@ -294,12 +317,12 @@ simple_conv = Conversation(
simple_conv_multimodal = Conversation(
    system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Hi!"),
        ("Assistant", "Hi there! How can I help you today?\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,

@@ -321,12 +344,12 @@ simple_conv_mpt_multimodal = Conversation(
simple_conv_legacy = Conversation(
    system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Hi!\n\n### Response:"),
        ("Assistant", "Hi there! How can I help you today?\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,

@@ -335,8 +358,8 @@ simple_conv_legacy = Conversation(
conv_llava_v1 = Conversation(
    system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),

@@ -354,7 +377,6 @@ conv_templates = {
    "multimodal": simple_conv_multimodal,
    "mpt_multimodal": simple_conv_mpt_multimodal,
    "llava_v1": conv_llava_v1,
    # fastchat
    "v1": conv_v1_2,
    "bair_v1": conv_bair_v1,
|
306 |
+
"You are designed to assist human with a variety of tasks using natural language."
|
307 |
+
"Follow the instructions carefully.",
|
308 |
roles=("Human", "Assistant"),
|
309 |
messages=(
|
310 |
("Human", "Hi!"),
|
311 |
+
("Assistant", "Hi there! How can I help you today?\n"),
|
312 |
),
|
313 |
offset=2,
|
314 |
sep_style=SeparatorStyle.SINGLE,
|
|
|
317 |
|
318 |
simple_conv_multimodal = Conversation(
|
319 |
system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
|
320 |
+
"You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
321 |
+
"Follow the instructions carefully and explain your answers in detail.",
|
322 |
roles=("Human", "Assistant"),
|
323 |
messages=(
|
324 |
("Human", "Hi!"),
|
325 |
+
("Assistant", "Hi there! How can I help you today?\n"),
|
326 |
),
|
327 |
offset=2,
|
328 |
sep_style=SeparatorStyle.SINGLE,
|
|
|
344 |
|
345 |
simple_conv_legacy = Conversation(
|
346 |
system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
|
347 |
+
"You are designed to assist human with a variety of tasks using natural language."
|
348 |
+
"Follow the instructions carefully.",
|
349 |
roles=("Human", "Assistant"),
|
350 |
messages=(
|
351 |
("Human", "Hi!\n\n### Response:"),
|
352 |
+
("Assistant", "Hi there! How can I help you today?\n"),
|
353 |
),
|
354 |
offset=2,
|
355 |
sep_style=SeparatorStyle.SINGLE,
|
|
|
358 |
|
359 |
conv_llava_v1 = Conversation(
|
360 |
system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
|
361 |
+
"You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
362 |
+
"Follow the instructions carefully and explain your answers in detail.",
|
363 |
roles=("USER", "ASSISTANT"),
|
364 |
version="v1",
|
365 |
messages=(),
|
|
|
377 |
"multimodal": simple_conv_multimodal,
|
378 |
"mpt_multimodal": simple_conv_mpt_multimodal,
|
379 |
"llava_v1": conv_llava_v1,
|
|
|
380 |
# fastchat
|
381 |
"v1": conv_v1_2,
|
382 |
"bair_v1": conv_bair_v1,
|
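These templates are what the chat and eval scripts copy from `conv_templates` when they build a prompt (see, e.g., model_vqa.py below). A minimal usage sketch, assuming the `Conversation` class keeps the upstream LLaVA helpers `append_message` and `get_prompt` in addition to the `copy`/`dict` methods shown in this diff:

```
from llava.conversation import conv_templates

# Start from the v1 template defined above (roles "USER" / "ASSISTANT").
conv = conv_templates["llava_v1"].copy()
conv.append_message(conv.roles[0], "What is shown in this image?\n<image>")
conv.append_message(conv.roles[1], None)  # leave the assistant turn open for generation
prompt = conv.get_prompt()
print(prompt)
```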
model/llava/eval/eval_gpt_review.py
CHANGED
This file gets a pure formatting pass (isort plus black). `import time` moves into the standard-library group and `ray`/`tqdm` are reordered; single-quoted strings become double-quoted; and the long expressions are wrapped across lines: the `openai.ChatCompletion.create(model="gpt-4", messages=[...], temperature=0.2, max_tokens=max_tokens)` request inside `get_eval`, the `parser.add_argument("--max-tokens", ...)` option, and the `js_list.append({...})` record with "id", "question_id", "answer1_id", "answer2_id", and "category". The logic is unchanged: each question is sent to GPT-4 together with the two answers and the rule prompt, `parse_score` reads the two scores from the first line of the reply, and the reviews are written to the output JSONL.
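As a quick standalone check of the scoring convention (this snippet is illustrative, not part of the commit), `parse_score` expects GPT-4 to put the two scores on the first line of its reply and falls back to `[-1, -1]` otherwise:

```
def parse_score(review):
    try:
        score_pair = review.split("\n")[0]
        score_pair = score_pair.replace(",", " ")
        sp = score_pair.split(" ")
        if len(sp) == 2:
            return [float(sp[0]), float(sp[1])]
        print("error", review)
        return [-1, -1]
    except Exception as e:
        print(e)
        print("error", review)
        return [-1, -1]


print(parse_score("8 7\nAssistant 1 is more detailed and accurate."))  # [8.0, 7.0]
print(parse_score("I cannot compare these answers."))                  # [-1, -1]
```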
model/llava/eval/eval_gpt_review_visual.py
CHANGED
Same formatting-only pass as eval_gpt_review.py: regrouped imports (`time`, `ray`, `tqdm`), double-quoted strings, and wrapped long expressions, including the GPT-4 chat request, the argparse options (this script adds `-c/--context` for the detection-context JSONL), the `cap_str`/`box_str` construction from the image captions and instance boxes, the [Context]/[Question]/[Answer] prompt f-string, and the `js_list.append({...})` record, which reads `ans1.get("answer_id", ans1["question_id"])` for the first answer id. Behaviour is otherwise identical.
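For concreteness, this is how the [Context]/[Question]/[Answer] message sent to GPT-4 is assembled. The sample dictionaries and rule text below are invented for illustration; the field names and f-strings follow the script:

```
inst = {
    "captions": ["A dog runs across a grassy field."],
    "instances": [{"category": "dog", "bbox": [0.12, 0.30, 0.55, 0.78]}],
}
ques = {"text": "What animal is in the picture?"}
ans1 = {"text": "A dog."}
ans2 = {"text": "A cat."}
role = "Answer"  # taken from the rule file in the real script
prompt = "Rate the two answers for helpfulness and accuracy."

cap_str = "\n".join(inst["captions"])
box_str = "\n".join(
    f'{instance["category"]}: {instance["bbox"]}' for instance in inst["instances"]
)
content = (
    f"[Context]\n{cap_str}\n\n{box_str}\n\n"
    f'[Question]\n{ques["text"]}\n\n'
    f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
    f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
    f"[System]\n{prompt}\n\n"
)
print(content)
```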
model/llava/eval/eval_science_qa.py
CHANGED
Formatting only: `re`/`random` are reordered in the imports, the argparse options (`--base-dir`, `--result-file`, `--output-file`, `--output-result`, `--split`, `--options`) are double-quoted, `options[: len(choices)]` gains black's slice spacing, and the multi-line literals (the `split_indices` lookup, the per-question `analysis` dict, the accuracy print) are wrapped. The evaluation itself is unchanged: the answer letter is pulled out with `re.compile(r"The answer is ([A-Z]).")`, mapped to an index with `get_pred_idx`, compared against the ground truth, and the `results`/`sqa_results` summaries are dumped to JSON.
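The answer-matching convention is easy to verify in isolation: predictions are expected to contain "The answer is X.", and anything unparseable is mapped to a random choice by `get_pred_idx`. A small self-contained check (not part of the commit):

```
import random
import re


def get_pred_idx(prediction, choices, options):
    """Get the index (e.g. 2) from the prediction (e.g. 'C')."""
    if prediction in options[: len(choices)]:
        return options.index(prediction)
    return random.choice(range(len(choices)))


options = ["A", "B", "C", "D", "E"]
pred_text = "The liquid has the lowest boiling point. The answer is C."

res = re.compile(r"The answer is ([A-Z]).").findall(pred_text)
answer = res[0] if len(res) == 1 else "FAILED"
print(answer)                                                          # C
print(get_pred_idx(answer, ["water", "ethanol", "acetone"], options))  # 2
```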
model/llava/eval/eval_science_qa_gpt4.py
CHANGED
The same formatting pass applied to the GPT-4 comparison script: reordered imports, double-quoted argparse options (`--base-dir`, `--gpt4-result`, `--our-result`, `--split`, `--options`), the shared `get_pred_idx` slice spacing, and wrapped accuracy prints. The comparison logic is intact: answers are parsed from both our predictions and the GPT-4 result file, a failed GPT-4 parse falls back to our prediction, and the script reports overall accuracy, the correct-upperbound, and how often GPT-4 produced no parseable answer.
model/llava/eval/eval_science_qa_gpt4_requery.py
CHANGED
The requery variant receives the same treatment: regrouped imports, double-quoted argparse options (`--base-dir`, `--gpt4-result`, `--requery-result`, `--our-result`, `--output-result`, `--split`, `--options`), and wrapped tally/print blocks. Functionally it still parses our answers, the GPT-4 answers, and the requery answers, falls back to our prediction when the requery fails, accumulates the `our_correct`, `gpt4_correct`, `gpt4_ourvisual_correct`, `requery_correct`, and `correct_upperbound` counters, prints their accuracies, and writes the final `sqa_results` (acc, correct, count, per-problem results) to `--output-result`.
model/llava/eval/generate_webpage_data_from_table.py
CHANGED
Formatting only: `models = ["vicuna"]`, the `read_jsonl(path: str, key: str = None)` signature, and the per-question record dicts ("id", "category", "question", "answers", "evaluations", "scores") are rewritten in the double-quoted, wrapped style, as are the review-cleanup code (splitting off a leading score line with `re.match(r"\d+[, ]+\d+", ...)` and bolding "Assistant 1"/"Assistant 2"), the id-reordering passes, and the final `json.dump({"questions": records, "models": models}, ...)` into webpage/data.json. The commented-out alpaca/llama/bard/gpt35 branches are kept as comments.
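The `read_jsonl` helper carries the whole script: it loads a .jsonl file and, when `key` is given, indexes the rows by that field so questions, answers, and reviews can be joined by question_id. The body below is a re-implementation for illustration (the middle of the original function falls outside this hunk); the signature and usage match:

```
import json


def read_jsonl(path: str, key: str = None):
    data = []
    with open(path) as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line))
    if key is not None:
        # Index rows by the requested field, e.g. question_id.
        data = {row[key]: row for row in data}
    return data


with open("/tmp/question_demo.jsonl", "w") as f:
    f.write(json.dumps({"question_id": 1, "category": "generic", "text": "Hi"}) + "\n")
    f.write(json.dumps({"question_id": 2, "category": "coding", "text": "Sort a list"}) + "\n")

questions = read_jsonl("/tmp/question_demo.jsonl", key="question_id")
print(questions[2]["text"])  # Sort a list
```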
model/llava/eval/model_qa.py
CHANGED
Formatting only: the imports are regrouped (`json`/`os`, then `shortuuid`/`torch`, the llava conversation helpers, `tqdm`, and `transformers`), the `KeywordsStoppingCriteria.__call__` signature and its `batch_decode(output_ids[:, self.start_len :], skip_special_tokens=True)` call are wrapped, the fp16 model load (`AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda()`) is split across lines, and the per-question `ans_file.write(json.dumps({...}) + "\n")` record ("question_id", "text", "answer_id", "model_id", "metadata") is expanded. Generation settings (do_sample, temperature 0.7, max_new_tokens 1024, the keyword stopping criteria) are unchanged.
model/llava/eval/model_vqa.py
CHANGED
This diff, too, is isort/black formatting through the setup code shown here: regrouped imports (including the multi-line `from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, StoppingCriteria)`), slice spacing in `split_list`, a trailing comma in `patch_config`'s patch_dict plus a wrapped warning print, and wrapped model/vision-tower setup: the fp16 `LlavaLlamaForCausalLM.from_pretrained(...)` load, the `CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)` call, the image start/end/patch token additions to the tokenizer, and the `vision_config.im_patch_token` lookup.
|
78 |
+
[DEFAULT_IMAGE_PATCH_TOKEN]
|
79 |
+
)[0]
|
80 |
vision_config.use_im_start_end = mm_use_im_start_end
|
81 |
if mm_use_im_start_end:
|
82 |
+
(
|
83 |
+
vision_config.im_start_token,
|
84 |
+
vision_config.im_end_token,
|
85 |
+
) = tokenizer.convert_tokens_to_ids(
|
86 |
+
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
|
87 |
+
)
|
88 |
image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
|
89 |
else:
|
90 |
# in case of using a pretrained model with only a MLP projector weights
|
91 |
+
model = LlavaLlamaForCausalLM.from_pretrained(
|
92 |
+
model_name, torch_dtype=torch.float16
|
93 |
+
).cuda()
|
94 |
|
95 |
+
vision_tower = CLIPVisionModel.from_pretrained(
|
96 |
+
args.vision_tower, torch_dtype=torch.float16
|
97 |
+
).cuda()
|
98 |
+
image_processor = CLIPImageProcessor.from_pretrained(
|
99 |
+
args.vision_tower, torch_dtype=torch.float16
|
100 |
+
)
|
101 |
|
102 |
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
103 |
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
104 |
if mm_use_im_start_end:
|
105 |
+
tokenizer.add_tokens(
|
106 |
+
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
|
107 |
+
)
|
108 |
|
109 |
vision_config = vision_tower.config
|
110 |
+
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
|
111 |
+
[DEFAULT_IMAGE_PATCH_TOKEN]
|
112 |
+
)[0]
|
113 |
vision_config.use_im_start_end = mm_use_im_start_end
|
114 |
if mm_use_im_start_end:
|
115 |
+
(
|
116 |
+
vision_config.im_start_token,
|
117 |
+
vision_config.im_end_token,
|
118 |
+
) = tokenizer.convert_tokens_to_ids(
|
119 |
+
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
|
120 |
+
)
|
121 |
|
122 |
image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
|
123 |
|
124 |
+
mm_projector = torch.nn.Linear(
|
125 |
+
vision_config.hidden_size, model.config.hidden_size
|
126 |
+
)
|
127 |
+
mm_projector_weights = torch.load(args.mm_projector, map_location="cpu")
|
128 |
+
mm_projector.load_state_dict(
|
129 |
+
{k.split(".")[-1]: v for k, v in mm_projector_weights.items()}
|
130 |
+
)
|
131 |
|
132 |
model.model.mm_projector = mm_projector.cuda().half()
|
133 |
model.model.vision_tower = [vision_tower]
|
134 |
|
135 |
+
questions = [
|
136 |
+
json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")
|
137 |
+
]
|
138 |
questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
|
139 |
answers_file = os.path.expanduser(args.answers_file)
|
140 |
os.makedirs(os.path.dirname(answers_file), exist_ok=True)
|
|
|
145 |
qs = line["text"]
|
146 |
cur_prompt = qs
|
147 |
if mm_use_im_start_end:
|
148 |
+
qs = (
|
149 |
+
qs
|
150 |
+
+ "\n"
|
151 |
+
+ DEFAULT_IM_START_TOKEN
|
152 |
+
+ DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
|
153 |
+
+ DEFAULT_IM_END_TOKEN
|
154 |
+
)
|
155 |
else:
|
156 |
+
qs = qs + "\n" + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
|
157 |
|
158 |
+
if args.conv_mode == "simple_legacy":
|
159 |
+
qs += "\n\n### Response:"
|
160 |
# conv = default_conversation.copy()
|
161 |
conv = conv_templates[args.conv_mode].copy()
|
162 |
conv.append_message(conv.roles[0], qs)
|
|
|
165 |
|
166 |
image = Image.open(os.path.join(args.image_folder, image_file))
|
167 |
# image.save(os.path.join(save_image_folder, image_file))
|
168 |
+
image_tensor = image_processor.preprocess(image, return_tensors="pt")[
|
169 |
+
"pixel_values"
|
170 |
+
][0]
|
171 |
|
172 |
input_ids = torch.as_tensor(inputs.input_ids).cuda()
|
173 |
|
|
|
179 |
self.start_len = None
|
180 |
self.input_ids = input_ids
|
181 |
|
182 |
+
def __call__(
|
183 |
+
self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
184 |
+
) -> bool:
|
185 |
if self.start_len is None:
|
186 |
self.start_len = self.input_ids.shape[1]
|
187 |
else:
|
188 |
+
outputs = self.tokenizer.batch_decode(
|
189 |
+
output_ids[:, self.start_len :], skip_special_tokens=True
|
190 |
+
)[0]
|
191 |
for keyword in self.keywords:
|
192 |
if keyword in outputs:
|
193 |
return True
|
194 |
return False
|
195 |
|
196 |
+
keywords = ["###"]
|
197 |
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
198 |
|
199 |
with torch.inference_mode():
|
|
|
203 |
do_sample=True,
|
204 |
temperature=0.7,
|
205 |
max_new_tokens=1024,
|
206 |
+
stopping_criteria=[stopping_criteria],
|
207 |
+
)
|
208 |
|
209 |
input_token_len = input_ids.shape[1]
|
210 |
+
n_diff_input_output = (
|
211 |
+
(input_ids != output_ids[:, :input_token_len]).sum().item()
|
212 |
+
)
|
213 |
if n_diff_input_output > 0:
|
214 |
+
print(
|
215 |
+
f"[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids"
|
216 |
+
)
|
217 |
+
outputs = tokenizer.batch_decode(
|
218 |
+
output_ids[:, input_token_len:], skip_special_tokens=True
|
219 |
+
)[0]
|
220 |
+
|
221 |
+
if args.conv_mode == "simple_legacy" or args.conv_mode == "simple":
|
222 |
while True:
|
223 |
cur_len = len(outputs)
|
224 |
outputs = outputs.strip()
|
225 |
+
for pattern in ["###", "Assistant:", "Response:"]:
|
226 |
if outputs.startswith(pattern):
|
227 |
+
outputs = outputs[len(pattern) :].strip()
|
228 |
if len(outputs) == cur_len:
|
229 |
break
|
230 |
|
|
|
237 |
outputs = outputs[:index].strip()
|
238 |
|
239 |
ans_id = shortuuid.uuid()
|
240 |
+
ans_file.write(
|
241 |
+
json.dumps(
|
242 |
+
{
|
243 |
+
"question_id": idx,
|
244 |
+
"prompt": cur_prompt,
|
245 |
+
"text": outputs,
|
246 |
+
"answer_id": ans_id,
|
247 |
+
"model_id": model_name,
|
248 |
+
"metadata": {},
|
249 |
+
}
|
250 |
+
)
|
251 |
+
+ "\n"
|
252 |
+
)
|
253 |
ans_file.flush()
|
254 |
ans_file.close()
|
255 |
|
256 |
+
|
257 |
if __name__ == "__main__":
|
258 |
parser = argparse.ArgumentParser()
|
259 |
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
|
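Several of these eval scripts rely on the same keyword-based stopping trick: decode only the newly generated tokens and stop once a marker such as "###" appears. Below is a minimal, self-contained sketch of that pattern; the "gpt2" checkpoint and the prompt are illustrative assumptions, the scripts above use the LLaVA checkpoints instead.

```
# Minimal sketch of the keyword-based stopping used by the eval scripts above.
# The "gpt2" checkpoint and the prompt are illustrative assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]  # only inspect newly generated tokens

    def __call__(
        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        new_text = self.tokenizer.batch_decode(
            output_ids[:, self.start_len :], skip_special_tokens=True
        )[0]
        return any(keyword in new_text for keyword in self.keywords)


tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("Question: what is 2 + 2?\nAnswer:", return_tensors="pt").input_ids
stopping_criteria = KeywordsStoppingCriteria(["###"], tokenizer, input_ids)

output_ids = model.generate(
    input_ids, max_new_tokens=64, stopping_criteria=[stopping_criteria]
)
print(tokenizer.decode(output_ids[0, input_ids.shape[1] :], skip_special_tokens=True))
```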
model/llava/eval/model_vqa_science.py
CHANGED
This diff is also formatting-only. The script evaluates on ScienceQA-style JSON (each record carries "conversations"; the `<image>` placeholder is stripped from the question, the image is preprocessed when present, and with `--answer_prompter` a second short generation pass is issued with the suffix `" ###\nANSWER:"`, the final answer being appended as `outputs_reasoning + "\n The answer is " + outputs`). The changes mirror model_vqa.py: imports sorted into stdlib and third-party groups with a combined `from transformers import (...)` statement, double quotes throughout (`prompt_pool = ["Describe the following image in detail."]`, `keywords = ["###"]`), a trailing comma after `"mm_hidden_size": 1024`, and rewrapped calls for model and CLIP loading (`use_cache=True` is kept on `LlavaLlamaForCausalLM.from_pretrained`), the token registration, the mm_projector setup, the image preprocessing, both `model.generate(...)` calls, the warning prints, and the JSONL answer writing. A representative reflowed hunk:

    qs = question["value"]

    qs = qs.replace("<image>", "").strip()
    cur_prompt = qs

    if "image" in line:
        image_file = line["image"]
        image = Image.open(os.path.join(args.image_folder, image_file))
        image_tensor = image_processor.preprocess(image, return_tensors="pt")[
            "pixel_values"
        ][0]
        images = image_tensor.unsqueeze(0).half().cuda()
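Both VQA eval scripts shard their question list with the same `split_list` / `get_chunk` helpers so that `args.num_chunks` / `args.chunk_idx` can split the work across processes. A small sketch of that helper pair follows; the body of `get_chunk` is not visible in this view and is reconstructed here by assumption as simple chunk indexing.

```
import math


def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks."""
    chunk_size = math.ceil(len(lst) / n)
    return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    # Assumed body: return the k-th of the n chunks for this worker.
    return split_list(lst, n)[k]


# Example: shard 10 questions across 3 workers; worker k=1 handles items 4..7.
questions = list(range(10))
print(get_chunk(questions, n=3, k=1))  # -> [4, 5, 6, 7]
```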
model/llava/eval/qa_baseline_gpt35.py
CHANGED
The whole file is reflowed with no functional change: `import concurrent.futures` moves into the sorted stdlib group, `import tqdm` follows `import shortuuid`, and every single-quoted literal becomes double-quoted. The reformatted core of the script, which generates baseline answers with the legacy `openai.ChatCompletion` API and retries each question up to three times, reads:

    MODEL = "gpt-3.5-turbo"
    MODEL_ID = "gpt-3.5-turbo:20230327"


    def get_answer(question_id: int, question: str, max_tokens: int):
        ans = {
            "answer_id": shortuuid.uuid(),
            "question_id": question_id,
            "model_id": MODEL_ID,
        }
        for _ in range(3):
            try:
                response = openai.ChatCompletion.create(
                    model=MODEL,
                    messages=[
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": question},
                    ],
                    max_tokens=max_tokens,
                )
                ans["text"] = response["choices"][0]["message"]["content"]
                return ans
            except Exception as e:
                print("[ERROR]", e)
                ans["text"] = "#ERROR#"
                time.sleep(1)
        return ans

The `__main__` block is likewise only rewrapped: the argparse definitions (`-q`/`--question`, `-o`/`--output`, `--max-tokens`) are split across lines, questions are read into `questions_dict[q["question_id"]] = q["text"]`, the collection loop becomes `for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):`, answers are sorted by `"question_id"`, and the output file is written with `f.write("\n".join(table))`.
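For context, here is a minimal sketch of the thread-pool fan-out pattern this baseline script uses: submit one task per question, collect results as they finish, then restore question order. The `answer_one` stand-in replaces the real API call and is an assumption for illustration only.

```
import concurrent.futures

import tqdm


def answer_one(question_id: int, question: str) -> dict:
    # Stand-in for the real API call; returns a fake answer record.
    return {"question_id": question_id, "text": question.upper()}


questions = {1: "what is in the image?", 2: "segment the dog."}
answers = []
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(answer_one, qid, q) for qid, q in questions.items()]
    for future in tqdm.tqdm(
        concurrent.futures.as_completed(futures), total=len(futures)
    ):
        answers.append(future.result())

answers.sort(key=lambda x: x["question_id"])
print(answers)
```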
model/llava/eval/run_llava.py
CHANGED
Formatting-only changes again: the duplicate `import os` is dropped, `from io import BytesIO` moves into the sorted import block at the top, and the llava / transformers imports are grouped. `load_image` is normalized to double quotes:

    def load_image(image_file):
        if image_file.startswith("http") or image_file.startswith("https"):
            response = requests.get(image_file)
            image = Image.open(BytesIO(response.content)).convert("RGB")
        else:
            image = Image.open(image_file).convert("RGB")
        return image

Model loading is rewrapped (the MPT branch keeps `low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True`; the LLaMA branch keeps `torch_dtype=torch.float16, device_map="auto"` with the `.cuda()` call left commented out), as are the special-token registration, the meta-device fallback that reloads the CLIP vision tower, the image-token prompt assembly, the conv-mode warning, the image preprocessing, and the generation call:

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor.unsqueeze(0).half().cuda(),
            do_sample=True,
            temperature=0.2,
            max_new_tokens=1024,
            stopping_criteria=[stopping_criteria],
        )

The decoded output is stripped of the stop string and printed; a leftover `import pdb` / `pdb.set_trace()` debugging hook at the end of `eval_model` is kept by the reformat.
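The prompt assembly these scripts rewrap is worth seeing in one place: the question text is extended with `<im_start>` / `<im_patch>` / `<im_end>` placeholder tokens, one `<im_patch>` per vision-tower patch. A minimal sketch follows; the 224-pixel image and 14-pixel patch numbers are an illustrative assumption about the CLIP ViT-L/14 tower, not values fixed by this commit.

```
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


def build_prompt(question: str, image_token_len: int, use_im_start_end: bool) -> str:
    # Splice one placeholder token per image patch after the question text.
    patches = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
    if use_im_start_end:
        return question + "\n" + DEFAULT_IM_START_TOKEN + patches + DEFAULT_IM_END_TOKEN
    return question + "\n" + patches


# (image_size // patch_size) ** 2, e.g. 256 patches for a 224x224 ViT-L/14 tower.
image_token_len = (224 // 14) ** 2
prompt = build_prompt("What is shown in the image?", image_token_len, use_im_start_end=True)
print(prompt[:120], "...")
```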
model/llava/eval/run_llava_batch.py
CHANGED
This batch-evaluation script, which globs ADE20K training images, builds a question per image, and dumps the model outputs to a JSON file, is only reformatted. The import block is sorted and deduplicated, and the previously compact `classes = [...]` list of the 150 ADE20K category names ("wall", "building", "sky", ..., "clock", "flag") is reflowed to one name per line. The model / vision-tower / image-processor loading, special-token registration, prompt assembly, and conv-mode warning are rewrapped exactly as in run_llava.py. The per-image loop keeps its behavior: images come from `sorted(glob.glob("/mnt/proj74/xinlai/dataset/ade20k/images/training/*.jpg"))`, each is preprocessed with `image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]`, generation uses `do_sample=True`, `temperature=0.2`, `max_new_tokens=512` (with a `# 1024,` comment left behind) and the keyword stopping criteria, the stop string is stripped, and the result is recorded as

    results.append(
        {
            "image_id": image_file.split("/")[-1],
            "input": input_conv,
            "output": outputs,
        }
    )

before everything is written to `/mnt/proj74/xinlai/LLM/LLaVA/ade20k_conversations.json`.
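As a reference for the loop structure the two batch scripts share, here is a minimal sketch: glob a directory of images, preprocess each with the CLIP image processor, and collect JSON records. The directory and output path below are placeholders, not the hard-coded cluster paths from the commit, and the model call itself is omitted.

```
import glob
import json

from PIL import Image
from transformers import CLIPImageProcessor

# Same CLIP tower name the scripts patch into the config.
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")

results = []
for image_file in sorted(glob.glob("data/ade20k/images/*.jpg")):  # placeholder path
    image = Image.open(image_file).convert("RGB")
    image_tensor = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
    # ... run the multimodal model on image_tensor here ...
    results.append({"image_id": image_file.split("/")[-1], "output": "<model output>"})

with open("ade20k_conversations.json", "w+") as f:  # placeholder output file
    json.dump(results, f)
```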
model/llava/eval/run_llava_batch_v2.py
CHANGED
The changes mirror run_llava_batch.py: imports are regrouped and sorted, long calls are wrapped in black style, the ADE20K class-name list is spelled out, and every class present in an annotation is prompted and decoded in one batch per image. Reconstructed new side ("# ..." marks context the diff view collapses):

import argparse
import glob
import json
import os
from io import BytesIO

import numpy as np
import requests
import torch
import tqdm
from llava.conversation import SeparatorStyle, conv_templates
from llava.model import *
from llava.model.utils import KeywordsStoppingCriteria
from llava.utils import disable_torch_init
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          CLIPImageProcessor, CLIPVisionModel,
                          StoppingCriteria)

DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


def load_image(image_file):
    if image_file.startswith("http") or image_file.startswith("https"):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image


# The 150 ADE20K category names used to build per-class prompts.
classes = [
    "wall", "building", "sky", "floor", "tree", "ceiling", "road", "bed",
    "windowpane", "grass", "cabinet", "sidewalk", "person", "earth", "door",
    "table", "mountain", "plant", "curtain", "chair", "car", "water",
    "painting", "sofa", "shelf", "house", "sea", "mirror", "rug", "field",
    "armchair", "seat", "fence", "desk", "rock", "wardrobe", "lamp", "bathtub",
    "railing", "cushion", "base", "box", "column", "signboard",
    "chest of drawers", "counter", "sand", "sink", "skyscraper", "fireplace",
    "refrigerator", "grandstand", "path", "stairs", "runway", "case",
    "pool table", "pillow", "screen door", "stairway", "river", "bridge",
    "bookcase", "blind", "coffee table", "toilet", "flower", "book", "hill",
    "bench", "countertop", "stove", "palm", "kitchen island", "computer",
    "swivel chair", "boat", "bar", "arcade machine", "hovel", "bus", "towel",
    "light", "truck", "tower", "chandelier", "awning", "streetlight", "booth",
    "television receiver", "airplane", "dirt track", "apparel", "pole", "land",
    "bannister", "escalator", "ottoman", "bottle", "buffet", "poster", "stage",
    "van", "ship", "fountain", "conveyer belt", "canopy", "washer",
    "plaything", "swimming pool", "stool", "barrel", "basket", "waterfall",
    "tent", "bag", "minibike", "cradle", "oven", "ball", "food", "step",
    "tank", "trade name", "microwave", "pot", "animal", "bicycle", "lake",
    "dishwasher", "screen", "blanket", "sculpture", "hood", "sconce", "vase",
    "traffic light", "tray", "ashcan", "fan", "pier", "crt screen", "plate",
    "monitor", "bulletin board", "shower", "radiator", "glass", "clock",
    "flag",
]


def eval_model(args):
    # Model
    # ...
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if "mpt" in model_name.lower():
        model = LlavaMPTForCausalLM.from_pretrained(
            model_name,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
            use_cache=True,
        ).cuda()
    else:
        model = LlavaLlamaForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16, device_map="auto"
        )  # .cuda()
    image_processor = CLIPImageProcessor.from_pretrained(
        model.config.mm_vision_tower, torch_dtype=torch.float16
    )

    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )

    vision_tower = model.get_model().vision_tower[0]
    if vision_tower.device.type == "meta":
        vision_tower = CLIPVisionModel.from_pretrained(
            vision_tower.config._name_or_path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).cuda()
        model.get_model().vision_tower[0] = vision_tower
    # else:
    #     vision_tower.to(device='cuda', dtype=torch.float16)
    vision_config = vision_tower.config
    vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
        [DEFAULT_IMAGE_PATCH_TOKEN]
    )[0]
    vision_config.use_im_start_end = mm_use_im_start_end
    if mm_use_im_start_end:
        (
            vision_config.im_start_token,
            vision_config.im_end_token,
        ) = tokenizer.convert_tokens_to_ids(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
        )
    image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2

    # paths for all images
    images = sorted(
        glob.glob("/mnt/proj74/xinlai/dataset/ade20k/images/training/*.jpg")
    )
    results = []
    for i, image_file in enumerate(tqdm.tqdm(images)):
        # ...
        label_unique = np.unique(label)

        image = load_image(image_file)
        image_tensor = image_processor.preprocess(image, return_tensors="pt")[
            "pixel_values"
        ][0]
        image_tensor = image_tensor.unsqueeze(0).half().cuda()

        for label in label_unique:
            # ... build one question per class
            if mm_use_im_start_end:
                qs = (
                    qs
                    + "\n"
                    + DEFAULT_IM_START_TOKEN
                    + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
                    + DEFAULT_IM_END_TOKEN
                )
            else:
                qs = qs + "\n" + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len

        # ... conversation-mode selection and a single batched generate() call
        #     (sampling disabled, max_new_tokens=512, keyword stopping criteria),
        #     plus the usual warning when output_ids do not start with input_ids

        outputs_list = []
        for output_id in output_ids:
            outputs = tokenizer.batch_decode(
                output_id[:, input_token_len:], skip_special_tokens=True
            )[0]
            outputs = outputs.strip()
            if outputs.endswith(stop_str):
                outputs = outputs[: -len(stop_str)]
            outputs = outputs.strip()
            outputs_list.append(outputs)

        for qs, outputs in zip(prompt_list, outputs_list):
            print("qs: {}, output: {}, image_file: {}".format(qs, outputs, image_file))

        results.append(
            {
                "image_id": image_file.split("/")[-1],
                "input": prompt_list,
                "output": outputs_list,
            }
        )

    with open("/mnt/proj74/xinlai/LLM/LLaVA/ade20k_conversations.json", "w+") as f:
        json.dump(results, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
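For each training image the v2 script reads the segmentation annotation, takes the unique label IDs, and asks one question per class that is actually present. A small sketch of that pattern (the question template and the 0-as-unlabeled convention are assumptions for illustration):

```
import numpy as np


def build_class_prompts(label_map: np.ndarray, classes: list) -> list:
    """Turn an ADE20K label map into one prompt per visible class."""
    prompts = []
    for class_id in np.unique(label_map):
        if class_id == 0:  # assumed: 0 marks unlabeled pixels
            continue
        name = classes[int(class_id) - 1]  # ADE20K class IDs are 1-based
        prompts.append(f"Describe the {name} in the image.")
    return prompts
```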
model/llava/eval/run_llava_batch_v3.py
CHANGED
run_llava_batch_v3.py gets the identical import block, load_image helper, ADE20K classes list, and model/vision-tower setup as run_llava_batch_v2.py. The v3-specific differences on the new side: the vision tower is explicitly moved to CUDA in the else branch (vision_tower.to(device="cuda", dtype=torch.float16)), sampling is enabled (do_sample=True, temperature=0.2), the image list is sharded with a --range start,end argument, and results are written to one JSON per image instead of a single combined file:

    # paths for all images
    images = sorted(
        glob.glob("/mnt/proj74/xinlai/dataset/ade20k/images/training/*.jpg")
    )
    start, end = args.range.split(",")
    start, end = int(start), int(end)
    images = images[start:end]
    results = []
    for i, image_file in enumerate(tqdm.tqdm(images)):
        # ...
        print("i: {}, len(images): {}".format(i, len(images)))

        image = load_image(image_file)
        image_tensor = image_processor.preprocess(image, return_tensors="pt")[
            "pixel_values"
        ][0]
        image_tensor = image_tensor.unsqueeze(0).half().cuda()

        prompt_list = []
        # ... per-class prompt construction and generation as in v2, but sampled,
        #     with max_new_tokens=512 and the keyword stopping criteria
        outputs = tokenizer.batch_decode(
            output_ids[:, input_token_len:], skip_special_tokens=True
        )[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[: -len(stop_str)]
        outputs = outputs.strip()

        # results.append({'image_id': image_id, 'input': input_conv, 'output': outputs})
        output_list.append(outputs)
        image_id = image_file.split("/")[-1].split(".")[0]
        with open(
            "/mnt/proj74/xinlai/LLM/LLaVA/generated/{}.json".format(image_id), "w+"
        ) as f:
            json.dump(
                {
                    "image_id": image_id,
                    "input_list": prompt_list,
                    "output_list": output_list,
                },
                f,
            )

    # with open("/mnt/proj74/xinlai/LLM/LLaVA/ade20k_conversations.json", "w+") as f:
    #     json.dump(results, f)

    # print(outputs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
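Because v3 writes one JSON per image, the --range argument lets several jobs split the image list and run in parallel. A standalone sketch of that sharding step (paths and defaults are placeholders):

```
import argparse
import glob

parser = argparse.ArgumentParser()
parser.add_argument("--range", default="0,1000", help="start,end slice of the sorted image list")
args = parser.parse_args()

start, end = (int(x) for x in args.range.split(","))
images = sorted(glob.glob("/path/to/ade20k/images/training/*.jpg"))[start:end]
print(f"processing {len(images)} images: [{start}, {end})")
```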
model/llava/eval/summarize_gpt_review.py
CHANGED
Cosmetic cleanup only (consistent double quotes, wrapped comprehension, spacing around operators). New version of the script body:

# (json, os, and collections.defaultdict are imported at the top of the file)
import numpy as np

if __name__ == "__main__":
    base_dir = "vqa/reviews/coco2014_val80"
    review_files = [
        x
        for x in os.listdir(base_dir)
        if x.endswith(".jsonl") and x.startswith("gpt4_text")
    ]

    for review_file in sorted(review_files):
        config = review_file.replace("gpt4_text_", "").replace(".jsonl", "")
        scores = defaultdict(list)
        print(f"GPT-4 vs. {config}")
        with open(os.path.join(base_dir, review_file)) as f:
            for review_str in f:
                review = json.loads(review_str)
                scores[review["category"]].append(review["tuple"])
                scores["all"].append(review["tuple"])
        for k, v in scores.items():
            stats = np.asarray(v).mean(0).tolist()
            stats = [round(x, 3) for x in stats]
            print(k, stats, round(stats[1] / stats[0] * 100, 1))
        print("=================================")
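The summarizer expects one JSON object per line with a `category` and a two-score `tuple`; judging from the `stats[1] / stats[0]` ratio, the first score belongs to GPT-4 and the second to the evaluated model. An illustrative input line (field values are invented):

```
{"category": "conv", "tuple": [8.5, 7.0]}
```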
model/llava/model/__init__.py
CHANGED
The two re-exports are rewritten to name the exported classes explicitly:

from .llava import LlavaConfig, LlavaLlamaForCausalLM
from .llava_mpt import LlavaMPTConfig, LlavaMPTForCausalLM
model/llava/model/apply_delta.py
CHANGED
Imports are re-ordered (the llava import now precedes tqdm/transformers) and the long calls and asserts are wrapped; behavior is unchanged. New version of the delta-merging function:

import argparse

import torch
from llava import LlavaLlamaForCausalLM
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def apply_delta(base_model_path, target_model_path, delta_path):
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )

    print("Loading delta")
    delta = LlavaLlamaForCausalLM.from_pretrained(
        delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)

    print("Applying delta")
    for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
        if name not in base.state_dict():
            assert name in [
                "model.mm_projector.weight",
                "model.mm_projector.bias",
            ], f"{name} not in base model"
            continue
        if param.data.shape == base.state_dict()[name].shape:
            param.data += base.state_dict()[name]
        else:
            assert name in [
                "model.embed_tokens.weight",
                "lm_head.weight",
            ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
            bparam = base.state_dict()[name]
            param.data[: bparam.shape[0], : bparam.shape[1]] += bparam

    print("Saving target model")
    delta.save_pretrained(target_model_path)
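A hedged example of driving the reconstructed function directly from Python; the checkpoint names are placeholders, and only the apply_delta(base_model_path, target_model_path, delta_path) signature comes from the file itself:

```
# Merge a base LLaMA checkpoint with released LLaVA delta weights.
apply_delta(
    base_model_path="path/to/llama-7b",    # placeholder: the original base LLM
    target_model_path="path/to/llava-7b",  # placeholder: where the merged model is saved
    delta_path="path/to/llava-7b-delta",   # placeholder: the published delta weights
)
```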
model/llava/model/consolidate.py
CHANGED
Same treatment (import reordering, wrapped from_pretrained call); behavior is unchanged:

import argparse

import torch
from llava.model import *
from llava.model.utils import auto_upgrade
from transformers import AutoModelForCausalLM, AutoTokenizer


def consolidate_ckpt(src_path, dst_path):
    print("Loading model")
    auto_upgrade(src_path)
    src_model = AutoModelForCausalLM.from_pretrained(
        src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    src_tokenizer = AutoTokenizer.from_pretrained(src_path)
    src_model.save_pretrained(dst_path)
    src_tokenizer.save_pretrained(dst_path)
model/llava/model/llava.py
CHANGED
The model file gets the same isort/black pass plus a few functional touches visible on the new side: the transformers imports are made explicit, initialize_vision_modules gains tune_mm_mlp_adapter and precision arguments (casting the frozen CLIP tower to bf16/fp16/fp32), the token-splicing logic in forward is wrapped into readable conditionals, and initialize_vision_tokenizer now takes num_new_tokens and device explicitly. Key excerpts from the new side ("# ..." marks elided code):

from transformers import (AutoConfig, AutoModelForCausalLM, CLIPImageProcessor,
                          CLIPVisionModel, LlamaConfig, LlamaForCausalLM,
                          LlamaModel)
from transformers.modeling_outputs import (BaseModelOutputWithPast,
                                            CausalLMOutputWithPast)

# ...

class LlavaLlamaModel(LlamaModel):
    # ...
    def initialize_vision_modules(
        self,
        vision_tower,
        mm_vision_select_layer,
        pretrain_mm_mlp_adapter=None,
        tune_mm_mlp_adapter=False,
        precision="bf16",
    ):
        self.config.mm_vision_tower = vision_tower
        image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
        if not hasattr(self, "vision_tower"):
            vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
        else:
            vision_tower = self.vision_tower[0]
        vision_tower.requires_grad_(False)
        if precision == "bf16":
            vision_tower = vision_tower.to(torch.bfloat16)
        elif precision == "fp16":
            vision_tower = vision_tower.to(torch.half)
        else:
            vision_tower = vision_tower.to(torch.float32)
        # ...
        if not hasattr(self, "mm_projector"):
            self.mm_projector = nn.Linear(
                vision_config.hidden_size, self.config.hidden_size
            )
        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(
                pretrain_mm_mlp_adapter, map_location="cpu"
            )
            self.mm_projector.load_state_dict(
                {k.split(".")[-1]: v for k, v in mm_projector_weights.items()}
            )
        return dict(
            image_processor=image_processor,
            image_token_len=num_patches,
            vision_config=vision_config,
        )

    def forward(self, ..., images=None, return_dict=None):
        # HACK: replace back original embeddings for LLaVA pretraining
        orig_embeds_params = getattr(self, "orig_embeds_params", None)
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        vision_tower = getattr(self, "vision_tower", None)
        if (
            vision_tower is not None
            and (input_ids.shape[1] != 1 or self.training)
            and images is not None
        ):
            # Encode the image(s) with CLIP, keep the hidden state selected by
            # mm_vision_select_layer (dropping the CLS token), project it with
            # mm_projector, and splice the patch embeddings into inputs_embeds at
            # the <im_patch> positions, with <im_start>/<im_end> handling and
            # detached text embeddings when only the adapter is being tuned.
            # ...
        return super(LlavaLlamaModel, self).forward(
            input_ids=None,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

# ...

AutoConfig.register("llava", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)

In LlavaLlamaForCausalLM, forward now resolves output_attentions / output_hidden_states / return_dict from the config through wrapped conditionals and passes images=images to the model, prepare_inputs_for_generation gets an explicit multi-line signature, and initialize_vision_tokenizer(mm_use_im_start_end, tokenizer, num_new_tokens, device, tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None) keeps the same embedding-initialization logic with the long assertions and the final im_patch_token lookup wrapped.
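The core of the forward logic above is splicing the projected image features into the token-embedding sequence at the <im_patch> positions. A stripped-down sketch of that splice for a single sample (no <im_start>/<im_end> handling, no gradient detaching; function and variable names are illustrative):

```
import torch


def splice_image_features(input_ids, inputs_embeds, image_features, im_patch_token):
    """Replace the run of <im_patch> embeddings with projected image features."""
    patch_positions = torch.where(input_ids == im_patch_token)[0]
    start = patch_positions[0]
    num_patches = image_features.shape[0]
    return torch.cat(
        (
            inputs_embeds[:start],                # text embeddings before the image
            image_features,                       # one embedding per vision patch
            inputs_embeds[start + num_patches:],  # text embeddings after the image
        ),
        dim=0,
    )
```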
model/llava/model/llava_mpt.py
CHANGED
13 |
# limitations under the License.
|
14 |
|
15 |
|
16 |
+
import math
|
17 |
import warnings
|
18 |
+
from typing import List, Optional, Tuple, Union
|
19 |
|
20 |
import torch
|
21 |
import torch.nn as nn
|
22 |
import torch.nn.functional as F
|
23 |
from torch.nn import CrossEntropyLoss
|
24 |
+
from transformers import (AutoConfig, AutoModelForCausalLM, CLIPImageProcessor,
|
25 |
+
CLIPVisionModel)
|
26 |
+
from transformers.modeling_outputs import (BaseModelOutputWithPast,
|
27 |
+
CausalLMOutputWithPast)
|
|
|
|
|
|
|
28 |
|
29 |
from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel
|
30 |
|
|
|
31 |
DEFAULT_IMAGE_TOKEN = "<image>"
|
32 |
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
33 |
DEFAULT_IM_START_TOKEN = "<im_start>"
|
|
|
46 |
|
47 |
if hasattr(config, "mm_vision_tower"):
|
48 |
# HACK: for FSDP
|
49 |
+
self.vision_tower = [
|
50 |
+
CLIPVisionModel.from_pretrained(config.mm_vision_tower)
|
51 |
+
]
|
52 |
# self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower)
|
53 |
|
54 |
if hasattr(config, "use_mm_proj"):
|
55 |
self.mm_projector = nn.Linear(config.mm_hidden_size, config.d_model)
|
56 |
|
57 |
+
def initialize_vision_modules(
|
58 |
+
self,
|
59 |
+
vision_tower,
|
60 |
+
mm_vision_select_layer,
|
61 |
+
pretrain_mm_mlp_adapter=None,
|
62 |
+
tune_mm_mlp_adapter=False,
|
63 |
+
):
|
64 |
self.config.mm_vision_tower = vision_tower
|
65 |
|
66 |
image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
|
67 |
|
68 |
+
if not hasattr(self, "vision_tower"):
|
69 |
vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
|
70 |
else:
|
71 |
vision_tower = self.vision_tower[0]
|
|
|
80 |
self.config.mm_hidden_size = vision_config.hidden_size
|
81 |
self.config.mm_vision_select_layer = mm_vision_select_layer
|
82 |
|
83 |
+
if not hasattr(self, "mm_projector"):
|
84 |
+
self.mm_projector = nn.Linear(
|
85 |
+
vision_config.hidden_size, self.config.d_model
|
86 |
+
)
|
87 |
|
88 |
if pretrain_mm_mlp_adapter is not None:
|
89 |
+
mm_projector_weights = torch.load(
|
90 |
+
pretrain_mm_mlp_adapter, map_location="cpu"
|
91 |
+
)
|
92 |
+
self.mm_projector.load_state_dict(
|
93 |
+
{
|
94 |
+
k.split(".")[-1]: v
|
95 |
+
for k, v in mm_projector_weights.items()
|
96 |
+
if "mm_projector" in k
|
97 |
+
}
|
98 |
+
)
|
99 |
|
100 |
return dict(
|
101 |
image_processor=image_processor,
|
102 |
image_token_len=num_patches,
|
103 |
+
vision_config=vision_config,
|
104 |
)
|
105 |
|
106 |
+
def forward(
|
107 |
+
self,
|
108 |
+
input_ids: torch.LongTensor,
|
109 |
+
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
110 |
+
attention_mask: Optional[torch.ByteTensor] = None,
|
111 |
+
prefix_mask: Optional[torch.ByteTensor] = None,
|
112 |
+
sequence_id: Optional[torch.LongTensor] = None,
|
113 |
+
return_dict: Optional[bool] = None,
|
114 |
+
output_attentions: Optional[bool] = None,
|
115 |
+
output_hidden_states: Optional[bool] = None,
|
116 |
+
use_cache: Optional[bool] = None,
|
117 |
+
images=None,
|
118 |
+
):
|
119 |
# HACK: replace back original embeddings for LLaVA pretraining
|
120 |
+
orig_embeds_params = getattr(self, "orig_embeds_params", None)
|
121 |
# if orig_embeds_params is not None:
|
122 |
# orig_embeds_params = orig_embeds_params[0]
|
123 |
# with torch.no_grad():
|
|
|
125 |
|
126 |
inputs_embeds = self.wte(input_ids)
|
127 |
|
128 |
+
vision_tower = getattr(self, "vision_tower", None)
|
129 |
+
if (
|
130 |
+
vision_tower is not None
|
131 |
+
and (input_ids.shape[1] != 1 or self.training)
|
132 |
+
and images is not None
|
133 |
+
):
|
134 |
# TODO: this is a modified multimodal LLM -- Haotian Liu
|
135 |
vision_tower = vision_tower[0] # HACK: for FSDP
|
136 |
with torch.no_grad():
|
|
|
138 |
# variable length images
|
139 |
image_features = []
|
140 |
for image in images:
|
141 |
+
image_forward_out = vision_tower(
|
142 |
+
image.unsqueeze(0), output_hidden_states=True
|
143 |
+
)
|
144 |
+
select_hidden_state_layer = getattr(
|
145 |
+
self.config, "mm_vision_select_layer", -1
|
146 |
+
)
|
147 |
+
select_hidden_state = image_forward_out.hidden_states[
|
148 |
+
select_hidden_state_layer
|
149 |
+
]
|
150 |
image_feature = select_hidden_state[:, 1:]
|
151 |
image_features.append(image_feature)
|
152 |
else:
|
153 |
image_forward_outs = vision_tower(images, output_hidden_states=True)
|
154 |
+
select_hidden_state_layer = getattr(
|
155 |
+
self.config, "mm_vision_select_layer", -1
|
156 |
+
)
|
157 |
+
select_hidden_state = image_forward_outs.hidden_states[
|
158 |
+
select_hidden_state_layer
|
159 |
+
]
|
160 |
image_features = select_hidden_state[:, 1:]
|
161 |
if type(images) is list:
|
162 |
+
image_features = [
|
163 |
+
self.mm_projector(image_feature)[0]
|
164 |
+
for image_feature in image_features
|
165 |
+
]
|
166 |
else:
|
167 |
image_features = self.mm_projector(image_features)
|
168 |
+
dummy_image_features = torch.zeros(
|
169 |
+
256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype
|
170 |
+
)
|
171 |
dummy_image_features = self.mm_projector(dummy_image_features)
|
172 |
|
173 |
new_input_embeds = []
|
|
|
175 |
for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
|
176 |
if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0:
|
177 |
# multimodal LLM, but the current sample is not multimodal
|
178 |
+
cur_input_embeds = (
|
179 |
+
cur_input_embeds + (0.0 * dummy_image_features).sum()
|
180 |
+
)
|
181 |
new_input_embeds.append(cur_input_embeds)
|
182 |
continue
|
183 |
if vision_tower.config.use_im_start_end:
|
184 |
cur_image_features = image_features[cur_image_idx]
|
185 |
num_patches = cur_image_features.shape[0]
|
186 |
+
if (cur_input_ids == vision_tower.config.im_start_token).sum() != (
|
187 |
+
cur_input_ids == vision_tower.config.im_end_token
|
188 |
+
).sum():
|
189 |
+
raise ValueError(
|
190 |
+
"The number of image start tokens and image end tokens should be the same."
|
191 |
+
)
|
192 |
+
image_start_tokens = torch.where(
|
193 |
+
cur_input_ids == vision_tower.config.im_start_token
|
194 |
+
)[0]
|
195 |
for image_start_token_pos in image_start_tokens:
|
196 |
+
cur_image_features = image_features[cur_image_idx].to(
|
197 |
+
device=cur_input_embeds.device
|
198 |
+
)
|
199 |
num_patches = cur_image_features.shape[0]
|
200 |
+
if (
|
201 |
+
cur_input_ids[image_start_token_pos + num_patches + 1]
|
202 |
+
!= vision_tower.config.im_end_token
|
203 |
+
):
|
204 |
+
raise ValueError(
|
205 |
+
"The image end token should follow the image start token."
|
206 |
+
)
|
207 |
if orig_embeds_params is not None:
|
208 |
+
cur_new_input_embeds = torch.cat(
|
209 |
+
(
|
210 |
+
cur_input_embeds[:image_start_token_pos].detach(),
|
211 |
+
cur_input_embeds[
|
212 |
+
image_start_token_pos : image_start_token_pos
|
213 |
+
+ 1
|
214 |
+
],
|
215 |
+
cur_image_features,
|
216 |
+
cur_input_embeds[
|
217 |
+
image_start_token_pos
|
218 |
+
+ num_patches
|
219 |
+
+ 1 : image_start_token_pos
|
220 |
+
+ num_patches
|
221 |
+
+ 2
|
222 |
+
],
|
223 |
+
cur_input_embeds[
|
224 |
+
image_start_token_pos + num_patches + 2 :
|
225 |
+
].detach(),
|
226 |
+
),
|
227 |
+
dim=0,
|
228 |
+
)
|
229 |
else:
|
230 |
+
cur_new_input_embeds = torch.cat(
|
231 |
+
(
|
232 |
+
cur_input_embeds[: image_start_token_pos + 1],
|
233 |
+
cur_image_features,
|
234 |
+
cur_input_embeds[
|
235 |
+
image_start_token_pos + num_patches + 1 :
|
236 |
+
],
|
237 |
+
),
|
238 |
+
dim=0,
|
239 |
+
)
|
240 |
cur_image_idx += 1
|
241 |
new_input_embeds.append(cur_new_input_embeds)
|
242 |
else:
|
243 |
cur_image_features = image_features[cur_image_idx]
|
244 |
num_patches = cur_image_features.shape[0]
|
245 |
+
if (
|
246 |
+
cur_input_ids == vision_tower.config.im_patch_token
|
247 |
+
).sum() != num_patches:
|
248 |
+
raise ValueError(
|
249 |
+
"The number of image patch tokens should be the same as the number of image patches."
|
250 |
+
)
|
251 |
+
masked_indices = torch.where(
|
252 |
+
cur_input_ids == vision_tower.config.im_patch_token
|
253 |
+
)[0]
|
254 |
mask_index_start = masked_indices[0]
|
255 |
+
if (
|
256 |
+
masked_indices
|
257 |
+
!= torch.arange(
|
258 |
+
mask_index_start,
|
259 |
+
mask_index_start + num_patches,
|
260 |
+
device=masked_indices.device,
|
261 |
+
dtype=masked_indices.dtype,
|
262 |
+
)
|
263 |
+
).any():
|
264 |
+
raise ValueError(
|
265 |
+
"The image patch tokens should be consecutive."
|
266 |
+
)
|
267 |
if orig_embeds_params is not None:
|
268 |
+
cur_new_input_embeds = torch.cat(
|
269 |
+
(
|
270 |
+
cur_input_embeds[:mask_index_start].detach(),
|
271 |
+
cur_image_features,
|
272 |
+
cur_input_embeds[
|
273 |
+
mask_index_start + num_patches :
|
274 |
+
].detach(),
|
275 |
+
),
|
276 |
+
dim=0,
|
277 |
+
)
|
278 |
else:
|
279 |
+
cur_new_input_embeds = torch.cat(
|
280 |
+
(
|
281 |
+
cur_input_embeds[:mask_index_start],
|
282 |
+
cur_image_features,
|
283 |
+
cur_input_embeds[mask_index_start + num_patches :],
|
284 |
+
),
|
285 |
+
dim=0,
|
286 |
+
)
|
287 |
new_input_embeds.append(cur_new_input_embeds)
|
288 |
inputs_embeds = torch.stack(new_input_embeds, dim=0)
|
289 |
|
290 |
+
return super(LlavaMPTModel, self).forward(
|
291 |
+
input_ids=None,
|
292 |
+
past_key_values=past_key_values,
|
293 |
+
attention_mask=attention_mask,
|
294 |
+
prefix_mask=prefix_mask,
|
295 |
+
sequence_id=sequence_id,
|
296 |
+
return_dict=return_dict,
|
297 |
+
output_attentions=output_attentions,
|
298 |
+
output_hidden_states=output_hidden_states,
|
299 |
+
use_cache=use_cache,
|
300 |
+
tok_emb=inputs_embeds,
|
301 |
+
)
|
302 |
|
303 |
|
304 |
class LlavaMPTForCausalLM(MPTForCausalLM):
|
|
|
309 |
super(MPTForCausalLM, self).__init__(config)
|
310 |
|
311 |
if not config.tie_word_embeddings:
|
312 |
+
raise ValueError("MPTForCausalLM only supports tied word embeddings")
|
313 |
self.transformer = LlavaMPTModel(config)
|
314 |
self.logit_scale = None
|
315 |
if config.logit_scale is not None:
|
316 |
logit_scale = config.logit_scale
|
317 |
if isinstance(logit_scale, str):
|
318 |
+
if logit_scale == "inv_sqrt_d_model":
|
319 |
logit_scale = 1 / math.sqrt(config.d_model)
|
320 |
else:
|
321 |
+
raise ValueError(
|
322 |
+
f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
|
323 |
+
)
|
324 |
self.logit_scale = logit_scale
|
325 |
|
326 |
def get_model(self):
|
|
|
330 |
if isinstance(module, LlavaMPTModel):
|
331 |
module.gradient_checkpointing = value
|
332 |
|
333 |
+
def forward(
|
334 |
+
self,
|
335 |
+
input_ids: torch.LongTensor,
|
336 |
+
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
337 |
+
attention_mask: Optional[torch.ByteTensor] = None,
|
338 |
+
prefix_mask: Optional[torch.ByteTensor] = None,
|
339 |
+
sequence_id: Optional[torch.LongTensor] = None,
|
340 |
+
labels: Optional[torch.LongTensor] = None,
|
341 |
+
return_dict: Optional[bool] = None,
|
342 |
+
output_attentions: Optional[bool] = None,
|
343 |
+
output_hidden_states: Optional[bool] = None,
|
344 |
+
use_cache: Optional[bool] = None,
|
345 |
+
images=None,
|
346 |
+
):
|
347 |
+
return_dict = (
|
348 |
+
return_dict if return_dict is not None else self.config.return_dict
|
349 |
+
)
|
350 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
351 |
+
outputs = self.transformer(
|
352 |
+
input_ids=input_ids,
|
353 |
+
past_key_values=past_key_values,
|
354 |
+
attention_mask=attention_mask,
|
355 |
+
prefix_mask=prefix_mask,
|
356 |
+
sequence_id=sequence_id,
|
357 |
+
return_dict=return_dict,
|
358 |
+
output_attentions=output_attentions,
|
359 |
+
output_hidden_states=output_hidden_states,
|
360 |
+
use_cache=use_cache,
|
361 |
+
images=images,
|
362 |
+
)
|
363 |
logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
|
364 |
if self.logit_scale is not None:
|
365 |
if self.logit_scale == 0:
|
366 |
+
warnings.warn(
|
367 |
+
f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs."
|
368 |
+
)
|
369 |
logits *= self.logit_scale
|
370 |
loss = None
|
371 |
if labels is not None:
|
372 |
labels = torch.roll(labels, shifts=-1)
|
373 |
labels[:, -1] = -100
|
374 |
+
loss = F.cross_entropy(
|
375 |
+
logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
|
376 |
+
)
|
377 |
+
return CausalLMOutputWithPast(
|
378 |
+
loss=loss,
|
379 |
+
logits=logits,
|
380 |
+
past_key_values=outputs.past_key_values,
|
381 |
+
hidden_states=outputs.hidden_states,
|
382 |
+
)
|
383 |
|
384 |
+
def prepare_inputs_for_generation(
|
385 |
+
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
|
386 |
+
):
|
387 |
if inputs_embeds is not None:
|
388 |
+
raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
|
389 |
+
attention_mask = kwargs["attention_mask"].bool()
|
390 |
if attention_mask[:, -1].sum() != attention_mask.shape[0]:
|
391 |
+
raise NotImplementedError(
|
392 |
+
"MPT does not support generation with right padding."
|
393 |
+
)
|
394 |
if self.transformer.attn_uses_sequence_id and self.training:
|
395 |
sequence_id = torch.zeros_like(input_ids[:1])
|
396 |
else:
|
|
|
399 |
input_ids = input_ids[:, -1].unsqueeze(-1)
|
400 |
if self.transformer.prefix_lm:
|
401 |
prefix_mask = torch.ones_like(attention_mask)
|
402 |
+
if kwargs.get("use_cache") == False:
|
403 |
+
raise NotImplementedError(
|
404 |
+
"MPT with prefix_lm=True does not support use_cache=False."
|
405 |
+
)
|
406 |
else:
|
407 |
prefix_mask = None
|
408 |
+
return {
|
409 |
+
"input_ids": input_ids,
|
410 |
+
"attention_mask": attention_mask,
|
411 |
+
"prefix_mask": prefix_mask,
|
412 |
+
"sequence_id": sequence_id,
|
413 |
+
"past_key_values": past_key_values,
|
414 |
+
"use_cache": kwargs.get("use_cache", True),
|
415 |
+
"images": kwargs.get("images", None),
|
416 |
+
}
|
417 |
+
|
418 |
+
def initialize_vision_tokenizer(
|
419 |
+
self,
|
420 |
+
mm_use_im_start_end,
|
421 |
+
tokenizer,
|
422 |
+
device,
|
423 |
+
tune_mm_mlp_adapter=False,
|
424 |
+
pretrain_mm_mlp_adapter=None,
|
425 |
+
):
|
426 |
vision_config = self.get_model().vision_tower[0].config
|
427 |
vision_config.use_im_start_end = mm_use_im_start_end
|
428 |
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
429 |
self.resize_token_embeddings(len(tokenizer))
|
430 |
|
431 |
if mm_use_im_start_end:
|
432 |
+
num_new_tokens = tokenizer.add_tokens(
|
433 |
+
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
|
434 |
+
)
|
435 |
self.resize_token_embeddings(len(tokenizer))
|
436 |
+
(
|
437 |
+
vision_config.im_start_token,
|
438 |
+
vision_config.im_end_token,
|
439 |
+
) = tokenizer.convert_tokens_to_ids(
|
440 |
+
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
|
441 |
+
)
|
442 |
|
443 |
if num_new_tokens > 0:
|
444 |
input_embeddings = self.get_input_embeddings().weight.data
|
445 |
output_embeddings = self.get_output_embeddings().weight.data
|
446 |
|
447 |
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
448 |
+
dim=0, keepdim=True
|
449 |
+
)
|
450 |
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
451 |
+
dim=0, keepdim=True
|
452 |
+
)
|
453 |
|
454 |
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
455 |
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
456 |
|
457 |
if tune_mm_mlp_adapter:
|
458 |
+
self.get_model().orig_embeds_params = [
|
459 |
+
self.get_input_embeddings().weight.data.clone().to(device=device)
|
460 |
+
]
|
461 |
for p in self.get_input_embeddings().parameters():
|
462 |
p.requires_grad = True
|
463 |
for p in self.get_output_embeddings().parameters():
|
464 |
p.requires_grad = False
|
465 |
|
466 |
if pretrain_mm_mlp_adapter:
|
467 |
+
mm_projector_weights = torch.load(
|
468 |
+
pretrain_mm_mlp_adapter, map_location="cpu"
|
469 |
+
)
|
470 |
+
embed_tokens_weight = mm_projector_weights["transformer.wte.weight"]
|
471 |
assert num_new_tokens == 2
|
472 |
if input_embeddings.shape == embed_tokens_weight.shape:
|
473 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight[
|
474 |
+
-num_new_tokens:
|
475 |
+
]
|
476 |
elif embed_tokens_weight.shape[0] == num_new_tokens:
|
477 |
input_embeddings[-num_new_tokens:] = embed_tokens_weight
|
478 |
else:
|
479 |
+
raise ValueError(
|
480 |
+
f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
|
481 |
+
)
|
482 |
+
|
483 |
+
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
|
484 |
+
[DEFAULT_IMAGE_PATCH_TOKEN]
|
485 |
+
)[0]
|
486 |
|
|
|
487 |
|
488 |
AutoConfig.register("llava_mpt", LlavaMPTConfig)
|
489 |
AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM)
|
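The `LlavaMPTModel.forward` above splices projected CLIP patch features into the token embedding sequence between the `<im_start>` and `<im_end>` positions. A self-contained sketch of that splice with dummy tensors (shapes and positions are made up for illustration, not taken from the diff):

```
import torch

# Dummy stand-ins for one sample's token embeddings and projected image features.
d_model, num_patches, seq_len = 16, 4, 10
cur_input_embeds = torch.randn(seq_len, d_model)
cur_image_features = torch.randn(num_patches, d_model)
image_start_token_pos = 2  # index of <im_start>; <im_patch> placeholders follow it

cur_new_input_embeds = torch.cat(
    (
        cur_input_embeds[: image_start_token_pos + 1],                # up to and incl. <im_start>
        cur_image_features,                                           # replaces the <im_patch> rows
        cur_input_embeds[image_start_token_pos + num_patches + 1 :],  # <im_end> onwards
    ),
    dim=0,
)
# Length is unchanged because the prompt already reserved num_patches placeholder tokens.
assert cur_new_input_embeds.shape == cur_input_embeds.shape
```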
model/llava/model/make_delta.py
CHANGED
import argparse

import torch
from llava.model.utils import auto_upgrade
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )

    print("Loading target model")
    auto_upgrade(target_model_path)
    target = AutoModelForCausalLM.from_pretrained(
        target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )

    print("Calculating delta")
    for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
        if name not in base.state_dict():
            assert name in [
                "model.mm_projector.weight",
                "model.mm_projector.bias",
            ], f"{name} not in base model"
            continue
        if param.data.shape == base.state_dict()[name].shape:
            param.data -= base.state_dict()[name]
        else:
            assert name in [
                "model.embed_tokens.weight",
                "lm_head.weight",
            ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
            bparam = base.state_dict()[name]
            param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam

    print("Saving delta")
    if hub_repo_id:
        # ... (saving/push logic elided in the diff view)

# ... (the `if __name__ == "__main__":` argument-parser block is elided in the diff view)
    parser.add_argument("--hub-repo-id", type=str, default=None)
    args = parser.parse_args()

    make_delta(
        args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id
    )
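A hypothetical call to the helper above; the paths are placeholders, and the CLI entry point simply wires the same arguments through argparse.

```
# Sketch under assumptions: all paths below are placeholders.
from llava.model.make_delta import make_delta

make_delta(
    base_model_path="path/to/llama-7b",   # original base weights
    target_model_path="path/to/target",   # fine-tuned weights to diff against the base
    delta_path="path/to/delta-out",       # where the weight delta is written
    hub_repo_id=None,                     # set to also push the delta to the Hub
)
```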
model/llava/model/mpt/adapt_tokenizer.py
CHANGED
from typing import Union

from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
NUM_SENTINEL_TOKENS: int = 100


def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
    """Adds sentinel tokens and padding token (if missing).

    ... (docstring body elided in the diff view)

    All added tokens are added as special tokens. No tokens are
    added if sentinel tokens and padding token already exist.
    """
    sentinels_to_add = [f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)]
    tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
    if tokenizer.pad_token is None:
        tokenizer.add_tokens("<pad>", special_tokens=True)
        tokenizer.pad_token = "<pad>"
        assert tokenizer.pad_token_id is not None
    sentinels = "".join([f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)])
    _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
    tokenizer.sentinel_token_ids = _sentinel_token_ids


class AutoTokenizerForMOD(AutoTokenizer):
    """AutoTokenizer + Adaptation for MOD.

    ... (rest of the class docstring and the `from_pretrained` signature are elided in the diff view)

        """See `AutoTokenizer.from_pretrained` docstring."""
        tokenizer = super().from_pretrained(*args, **kwargs)
        adapt_tokenizer_for_denoising(tokenizer)
        return tokenizer
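A sketch of what `adapt_tokenizer_for_denoising` does to a tokenizer; the import path and the checkpoint name are assumptions, not taken from the diff.

```
# Sketch under assumptions: import path inferred from the repo layout,
# "EleutherAI/gpt-neox-20b" used only as an example tokenizer.
from transformers import AutoTokenizer

from llava.model.mpt.adapt_tokenizer import adapt_tokenizer_for_denoising  # assumed path

tok = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
vocab_before = len(tok)
adapt_tokenizer_for_denoising(tok)
assert len(tok) >= vocab_before + 100  # 100 sentinel tokens, plus <pad> if it was missing
assert tok.pad_token_id is not None
print(len(tok.sentinel_token_ids))     # ids of the 100 sentinel tokens
```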
model/llava/model/mpt/attention.py
CHANGED
2 |
import math
|
3 |
import warnings
|
4 |
from typing import Optional
|
5 |
+
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
from einops import rearrange
|
9 |
from torch import nn
|
10 |
+
|
11 |
from .norm import LPLayerNorm
|
12 |
|
13 |
+
|
14 |
+
def _reset_is_causal(
|
15 |
+
num_query_tokens: int, num_key_tokens: int, original_is_causal: bool
|
16 |
+
):
|
17 |
if original_is_causal and num_query_tokens != num_key_tokens:
|
18 |
if num_query_tokens != 1:
|
19 |
+
raise NotImplementedError(
|
20 |
+
"MPT does not support query and key with different number of tokens, unless number of query tokens is 1."
|
21 |
+
)
|
22 |
else:
|
23 |
return False
|
24 |
return original_is_causal
|
25 |
|
26 |
+
|
27 |
+
def scaled_multihead_dot_product_attention(
|
28 |
+
query,
|
29 |
+
key,
|
30 |
+
value,
|
31 |
+
n_heads,
|
32 |
+
softmax_scale=None,
|
33 |
+
attn_bias=None,
|
34 |
+
key_padding_mask=None,
|
35 |
+
is_causal=False,
|
36 |
+
dropout_p=0.0,
|
37 |
+
training=False,
|
38 |
+
needs_weights=False,
|
39 |
+
multiquery=False,
|
40 |
+
):
|
41 |
+
q = rearrange(query, "b s (h d) -> b h s d", h=n_heads)
|
42 |
+
k = rearrange(key, "b s (h d) -> b h d s", h=1 if multiquery else n_heads)
|
43 |
+
v = rearrange(value, "b s (h d) -> b h s d", h=1 if multiquery else n_heads)
|
44 |
min_val = torch.finfo(q.dtype).min
|
45 |
(b, _, s_q, d) = q.shape
|
46 |
s_k = k.size(-1)
|
|
|
48 |
softmax_scale = 1 / math.sqrt(d)
|
49 |
attn_weight = q.matmul(k) * softmax_scale
|
50 |
if attn_bias is not None:
|
51 |
+
if (
|
52 |
+
attn_bias.size(-1) != 1
|
53 |
+
and attn_bias.size(-1) != s_k
|
54 |
+
or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q)
|
55 |
+
):
|
56 |
+
raise RuntimeError(
|
57 |
+
f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}."
|
58 |
+
)
|
59 |
attn_weight = attn_weight + attn_bias
|
60 |
if key_padding_mask is not None:
|
61 |
if attn_bias is not None:
|
62 |
+
warnings.warn(
|
63 |
+
"Propogating key_padding_mask to the attention module "
|
64 |
+
+ "and applying it within the attention module can cause "
|
65 |
+
+ "unneccessary computation/memory usage. Consider integrating "
|
66 |
+
+ "into attn_bias once and passing that to each attention "
|
67 |
+
+ "module instead."
|
68 |
+
)
|
69 |
+
attn_weight = attn_weight.masked_fill(
|
70 |
+
~key_padding_mask.view((b, 1, 1, s_k)), min_val
|
71 |
+
)
|
72 |
if is_causal:
|
73 |
s = max(s_q, s_k)
|
74 |
causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
|
|
|
79 |
attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
|
80 |
attn_weight = torch.softmax(attn_weight, dim=-1)
|
81 |
if dropout_p:
|
82 |
+
attn_weight = torch.nn.functional.dropout(
|
83 |
+
attn_weight, p=dropout_p, training=training, inplace=True
|
84 |
+
)
|
85 |
out = attn_weight.matmul(v)
|
86 |
+
out = rearrange(out, "b h s d -> b s (h d)")
|
87 |
if needs_weights:
|
88 |
return (out, attn_weight)
|
89 |
return (out, None)
|
90 |
|
91 |
+
|
92 |
def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
|
93 |
for tensor in tensors:
|
94 |
if tensor.dtype not in valid_dtypes:
|
95 |
+
raise TypeError(
|
96 |
+
f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}."
|
97 |
+
)
|
98 |
if not tensor.is_cuda:
|
99 |
+
raise TypeError(
|
100 |
+
f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})."
|
101 |
+
)
|
102 |
+
|
103 |
|
104 |
+
def flash_attn_fn(
|
105 |
+
query,
|
106 |
+
key,
|
107 |
+
value,
|
108 |
+
n_heads,
|
109 |
+
softmax_scale=None,
|
110 |
+
attn_bias=None,
|
111 |
+
key_padding_mask=None,
|
112 |
+
is_causal=False,
|
113 |
+
dropout_p=0.0,
|
114 |
+
training=False,
|
115 |
+
needs_weights=False,
|
116 |
+
multiquery=False,
|
117 |
+
):
|
118 |
try:
|
119 |
from flash_attn import bert_padding, flash_attn_interface
|
120 |
except:
|
121 |
+
raise RuntimeError("Please install flash-attn==1.0.3.post0")
|
122 |
check_valid_inputs(query, key, value)
|
123 |
if attn_bias is not None:
|
124 |
+
raise NotImplementedError(f"attn_bias not implemented for flash attn.")
|
125 |
(batch_size, seqlen) = query.shape[:2]
|
126 |
if key_padding_mask is None:
|
127 |
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
|
128 |
+
query_padding_mask = key_padding_mask[:, -query.size(1) :]
|
129 |
+
(query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
|
130 |
+
query, query_padding_mask
|
131 |
+
)
|
132 |
+
query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads)
|
133 |
+
(key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
|
134 |
+
key, key_padding_mask
|
135 |
+
)
|
136 |
+
key_unpad = rearrange(
|
137 |
+
key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads
|
138 |
+
)
|
139 |
(value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
|
140 |
+
value_unpad = rearrange(
|
141 |
+
value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads
|
142 |
+
)
|
143 |
if multiquery:
|
144 |
key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
|
145 |
+
value_unpad = value_unpad.expand(
|
146 |
+
value_unpad.size(0), n_heads, value_unpad.size(-1)
|
147 |
+
)
|
148 |
dropout_p = dropout_p if training else 0.0
|
149 |
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
150 |
+
output_unpad = flash_attn_interface.flash_attn_unpadded_func(
|
151 |
+
query_unpad,
|
152 |
+
key_unpad,
|
153 |
+
value_unpad,
|
154 |
+
cu_seqlens_q,
|
155 |
+
cu_seqlens_k,
|
156 |
+
max_seqlen_q,
|
157 |
+
max_seqlen_k,
|
158 |
+
dropout_p,
|
159 |
+
softmax_scale=softmax_scale,
|
160 |
+
causal=reset_is_causal,
|
161 |
+
return_attn_probs=needs_weights,
|
162 |
+
)
|
163 |
+
output = bert_padding.pad_input(
|
164 |
+
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen
|
165 |
+
)
|
166 |
return (output, None)
|
167 |
|
168 |
+
|
169 |
+
def triton_flash_attn_fn(
|
170 |
+
query,
|
171 |
+
key,
|
172 |
+
value,
|
173 |
+
n_heads,
|
174 |
+
softmax_scale=None,
|
175 |
+
attn_bias=None,
|
176 |
+
key_padding_mask=None,
|
177 |
+
is_causal=False,
|
178 |
+
dropout_p=0.0,
|
179 |
+
training=False,
|
180 |
+
needs_weights=False,
|
181 |
+
multiquery=False,
|
182 |
+
):
|
183 |
try:
|
184 |
from flash_attn import flash_attn_triton
|
185 |
except:
|
186 |
+
raise RuntimeError(
|
187 |
+
"Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202"
|
188 |
+
)
|
189 |
check_valid_inputs(query, key, value)
|
190 |
if dropout_p:
|
191 |
+
raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.")
|
192 |
if needs_weights:
|
193 |
+
raise NotImplementedError(f"attn_impl: triton cannot return attn weights.")
|
194 |
if key_padding_mask is not None:
|
195 |
+
warnings.warn(
|
196 |
+
"Propagating key_padding_mask to the attention module "
|
197 |
+
+ "and applying it within the attention module can cause "
|
198 |
+
+ "unnecessary computation/memory usage. Consider integrating "
|
199 |
+
+ "into attn_bias once and passing that to each attention "
|
200 |
+
+ "module instead."
|
201 |
+
)
|
202 |
(b_size, s_k) = key_padding_mask.shape[:2]
|
203 |
if attn_bias is None:
|
204 |
attn_bias = query.new_zeros(b_size, 1, 1, s_k)
|
205 |
+
attn_bias = attn_bias.masked_fill(
|
206 |
+
~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min
|
207 |
+
)
|
208 |
+
query = rearrange(query, "b s (h d) -> b s h d", h=n_heads)
|
209 |
+
key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
|
210 |
+
value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
|
211 |
if multiquery:
|
212 |
key = key.expand(*key.shape[:2], n_heads, key.size(-1))
|
213 |
value = value.expand(*value.shape[:2], n_heads, value.size(-1))
|
214 |
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
215 |
+
attn_output = flash_attn_triton.flash_attn_func(
|
216 |
+
query, key, value, attn_bias, reset_is_causal, softmax_scale
|
217 |
+
)
|
218 |
output = attn_output.view(*attn_output.shape[:2], -1)
|
219 |
return (output, None)
|
220 |
|
221 |
+
|
222 |
class MultiheadAttention(nn.Module):
|
223 |
"""Multi-head self attention.
|
224 |
|
|
|
226 |
additive bias.
|
227 |
"""
|
228 |
|
229 |
+
def __init__(
|
230 |
+
self,
|
231 |
+
d_model: int,
|
232 |
+
n_heads: int,
|
233 |
+
attn_impl: str = "triton",
|
234 |
+
clip_qkv: Optional[float] = None,
|
235 |
+
qk_ln: bool = False,
|
236 |
+
softmax_scale: Optional[float] = None,
|
237 |
+
attn_pdrop: float = 0.0,
|
238 |
+
low_precision_layernorm: bool = False,
|
239 |
+
device: Optional[str] = None,
|
240 |
+
):
|
241 |
super().__init__()
|
242 |
self.attn_impl = attn_impl
|
243 |
self.clip_qkv = clip_qkv
|
|
|
255 |
layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
|
256 |
self.q_ln = layernorm_class(self.d_model, device=device)
|
257 |
self.k_ln = layernorm_class(self.d_model, device=device)
|
258 |
+
if self.attn_impl == "flash":
|
259 |
self.attn_fn = flash_attn_fn
|
260 |
+
elif self.attn_impl == "triton":
|
261 |
self.attn_fn = triton_flash_attn_fn
|
262 |
+
warnings.warn(
|
263 |
+
"While `attn_impl: triton` can be faster than `attn_impl: flash` "
|
264 |
+
+ "it uses more memory. When training larger models this can trigger "
|
265 |
+
+ "alloc retries which hurts performance. If encountered, we recommend "
|
266 |
+
+ "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
|
267 |
+
)
|
268 |
+
elif self.attn_impl == "torch":
|
269 |
self.attn_fn = scaled_multihead_dot_product_attention
|
270 |
if torch.cuda.is_available():
|
271 |
+
warnings.warn(
|
272 |
+
"Using `attn_impl: torch`. If your model does not use `alibi` or "
|
273 |
+
+ "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
|
274 |
+
+ "we recommend using `attn_impl: triton`."
|
275 |
+
)
|
276 |
else:
|
277 |
+
raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
|
278 |
self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
|
279 |
self.out_proj._is_residual = True
|
280 |
|
281 |
+
def forward(
|
282 |
+
self,
|
283 |
+
x,
|
284 |
+
past_key_value=None,
|
285 |
+
attn_bias=None,
|
286 |
+
attention_mask=None,
|
287 |
+
is_causal=True,
|
288 |
+
needs_weights=False,
|
289 |
+
):
|
290 |
qkv = self.Wqkv(x)
|
291 |
if self.clip_qkv:
|
292 |
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
|
|
302 |
value = torch.cat([past_key_value[1], value], dim=1)
|
303 |
past_key_value = (key, value)
|
304 |
if attn_bias is not None:
|
305 |
+
attn_bias = attn_bias[:, :, -query.size(1) :, -key.size(1) :]
|
306 |
+
(context, attn_weights) = self.attn_fn(
|
307 |
+
query,
|
308 |
+
key,
|
309 |
+
value,
|
310 |
+
self.n_heads,
|
311 |
+
softmax_scale=self.softmax_scale,
|
312 |
+
attn_bias=attn_bias,
|
313 |
+
key_padding_mask=key_padding_mask,
|
314 |
+
is_causal=is_causal,
|
315 |
+
dropout_p=self.attn_dropout_p,
|
316 |
+
training=self.training,
|
317 |
+
needs_weights=needs_weights,
|
318 |
+
)
|
319 |
return (self.out_proj(context), attn_weights, past_key_value)
|
320 |
|
321 |
+
|
322 |
class MultiQueryAttention(nn.Module):
|
323 |
"""Multi-Query self attention.
|
324 |
|
|
|
326 |
additive bias.
|
327 |
"""
|
328 |
|
329 |
+
def __init__(
|
330 |
+
self,
|
331 |
+
d_model: int,
|
332 |
+
n_heads: int,
|
333 |
+
attn_impl: str = "triton",
|
334 |
+
clip_qkv: Optional[float] = None,
|
335 |
+
qk_ln: bool = False,
|
336 |
+
softmax_scale: Optional[float] = None,
|
337 |
+
attn_pdrop: float = 0.0,
|
338 |
+
low_precision_layernorm: bool = False,
|
339 |
+
device: Optional[str] = None,
|
340 |
+
):
|
341 |
super().__init__()
|
342 |
self.attn_impl = attn_impl
|
343 |
self.clip_qkv = clip_qkv
|
|
|
356 |
layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
|
357 |
self.q_ln = layernorm_class(d_model, device=device)
|
358 |
self.k_ln = layernorm_class(self.head_dim, device=device)
|
359 |
+
if self.attn_impl == "flash":
|
360 |
self.attn_fn = flash_attn_fn
|
361 |
+
elif self.attn_impl == "triton":
|
362 |
self.attn_fn = triton_flash_attn_fn
|
363 |
+
warnings.warn(
|
364 |
+
"While `attn_impl: triton` can be faster than `attn_impl: flash` "
|
365 |
+
+ "it uses more memory. When training larger models this can trigger "
|
366 |
+
+ "alloc retries which hurts performance. If encountered, we recommend "
|
367 |
+
+ "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
|
368 |
+
)
|
369 |
model/llava/model/mpt/attention.py (new version, continued):

        elif self.attn_impl == "torch":
            self.attn_fn = scaled_multihead_dot_product_attention
            if torch.cuda.is_available():
                warnings.warn(
                    "Using `attn_impl: torch`. If your model does not use `alibi` or "
                    + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
                    + "we recommend using `attn_impl: triton`."
                )
        else:
            raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
        self.out_proj._is_residual = True

    def forward(
        self,
        x,
        past_key_value=None,
        attn_bias=None,
        attention_mask=None,
        is_causal=True,
        needs_weights=False,
    ):
        qkv = self.Wqkv(x)
        if self.clip_qkv:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        (query, key, value) = qkv.split(
            [self.d_model, self.head_dim, self.head_dim], dim=2
        )
        key_padding_mask = attention_mask
        if self.qk_ln:
            dtype = query.dtype
            # ... (unchanged lines elided by the diff view) ...
            value = torch.cat([past_key_value[1], value], dim=1)
            past_key_value = (key, value)
        if attn_bias is not None:
            attn_bias = attn_bias[:, :, -query.size(1) :, -key.size(1) :]
        (context, attn_weights) = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
            multiquery=True,
        )
        return (self.out_proj(context), attn_weights, past_key_value)


def attn_bias_shape(
    attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id
):
    if attn_impl == "flash":
        return None
    elif attn_impl in ["torch", "triton"]:
        if alibi:
            if (prefix_lm or not causal) or use_sequence_id:
                return (1, n_heads, seq_len, seq_len)
            # ... (unchanged lines elided by the diff view) ...
            return (1, 1, seq_len, seq_len)
        return None
    else:
        raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")


def build_attn_bias(
    attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8
):
    if attn_impl == "flash":
        return None
    elif attn_impl in ["torch", "triton"]:
        if alibi:
            (device, dtype) = (attn_bias.device, attn_bias.dtype)
            attn_bias = attn_bias.add(
                build_alibi_bias(
                    n_heads,
                    seq_len,
                    full=not causal,
                    alibi_bias_max=alibi_bias_max,
                    device=device,
                    dtype=dtype,
                )
            )
        return attn_bias
    else:
        raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")


def gen_slopes(n_heads, alibi_bias_max=8, device=None):
    _n_heads = 2 ** math.ceil(math.log2(n_heads))
    # ... (unchanged lines elided by the diff view) ...
    slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
    return slopes.view(1, n_heads, 1, 1)


def build_alibi_bias(
    n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None
):
    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(
        1, 1, 1, seq_len
    )
    if full:
        alibi_bias = alibi_bias - torch.arange(
            1 - seq_len, 1, dtype=torch.int32, device=device
        ).view(1, 1, seq_len, 1)
    alibi_bias = alibi_bias.abs().mul(-1)
    slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
    alibi_bias = alibi_bias * slopes
    return alibi_bias.to(dtype=dtype)


ATTN_CLASS_REGISTRY = {
    "multihead_attention": MultiheadAttention,
    "multiquery_attention": MultiQueryAttention,
}
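The `gen_slopes` / `build_alibi_bias` pair above is what lets MPT drop learned positional embeddings when `alibi` is enabled: every head gets a geometrically decaying slope, and the bias added to the attention logits is that slope times the negative distance between query and key positions. A minimal self-contained sketch of the same computation (the function names here are illustrative, not part of the repo):

import math

import torch


def alibi_slopes(n_heads: int, bias_max: int = 8) -> torch.Tensor:
    # Round the head count up to a power of two so the slopes form a clean
    # geometric series, then keep the first n_heads (interleaved, as in MPT).
    n = 2 ** math.ceil(math.log2(n_heads))
    m = torch.arange(1, n + 1, dtype=torch.float32) * (bias_max / n)
    slopes = 1.0 / torch.pow(2, m)
    if n != n_heads:
        slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads]
    return slopes.view(1, n_heads, 1, 1)


def alibi_bias(n_heads: int, seq_len: int) -> torch.Tensor:
    # Bias = per-head slope * -|i - j|: more distant keys receive a larger
    # penalty, added to the raw attention scores before the softmax.
    pos = torch.arange(seq_len)
    dist = (pos.view(1, -1) - pos.view(-1, 1)).abs().mul(-1.0)
    return alibi_slopes(n_heads) * dist.view(1, 1, seq_len, seq_len)


bias = alibi_bias(n_heads=4, seq_len=8)  # shape (1, 4, 8, 8)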
model/llava/model/mpt/blocks.py
CHANGED
@@ -1,41 +1,90 @@
The hunk reformats the vendored MPT block module (wrapped call arguments, double-quoted strings). New version:

"""GPT Blocks used for the GPT Model."""
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn

from .attention import ATTN_CLASS_REGISTRY
from .norm import NORM_CLASS_REGISTRY


class MPTMLP(nn.Module):
    def __init__(
        self, d_model: int, expansion_ratio: int, device: Optional[str] = None
    ):
        super().__init__()
        self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
        self.act = nn.GELU(approximate="none")
        self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
        self.down_proj._is_residual = True

    def forward(self, x):
        return self.down_proj(self.act(self.up_proj(x)))


class MPTBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Dict = {
            "attn_type": "multihead_attention",
            "attn_pdrop": 0.0,
            "attn_impl": "triton",
            "qk_ln": False,
            "clip_qkv": None,
            "softmax_scale": None,
            "prefix_lm": False,
            "attn_uses_sequence_id": False,
            "alibi": False,
            "alibi_bias_max": 8,
        },
        resid_pdrop: float = 0.0,
        norm_type: str = "low_precision_layernorm",
        device: Optional[str] = None,
        **kwargs
    ):
        del kwargs
        super().__init__()
        norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
        attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
        self.norm_1 = norm_class(d_model, device=device)
        self.attn = attn_class(
            attn_impl=attn_config["attn_impl"],
            clip_qkv=attn_config["clip_qkv"],
            qk_ln=attn_config["qk_ln"],
            softmax_scale=attn_config["softmax_scale"],
            attn_pdrop=attn_config["attn_pdrop"],
            d_model=d_model,
            n_heads=n_heads,
            device=device,
        )
        self.norm_2 = norm_class(d_model, device=device)
        self.ffn = MPTMLP(
            d_model=d_model, expansion_ratio=expansion_ratio, device=device
        )
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_bias: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
        a = self.norm_1(x)
        (b, _, past_key_value) = self.attn(
            a,
            past_key_value=past_key_value,
            attn_bias=attn_bias,
            attention_mask=attention_mask,
            is_causal=is_causal,
        )
        x = x + self.resid_attn_dropout(b)
        m = self.norm_2(x)
        n = self.ffn(m)
        x = x + self.resid_ffn_dropout(n)
        return (x, past_key_value)
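`MPTBlock` above is a standard pre-norm transformer block: normalize, attend, add the residual, then normalize, run the MLP, and add the second residual, with dropout on each residual branch. A rough stand-in built only from stock PyTorch modules, sketching the same control flow (it does not reproduce MPT's custom attention or norm classes):

import torch
import torch.nn as nn


class PreNormBlock(nn.Module):
    """Minimal pre-norm residual block mirroring MPTBlock's control flow."""

    def __init__(self, d_model: int, n_heads: int, expansion_ratio: int = 4, p: float = 0.0):
        super().__init__()
        self.norm_1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm_2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, expansion_ratio * d_model),
            nn.GELU(approximate="none"),
            nn.Linear(expansion_ratio * d_model, d_model),
        )
        self.resid_attn_dropout = nn.Dropout(p)
        self.resid_ffn_dropout = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = self.norm_1(x)
        b, _ = self.attn(a, a, a, need_weights=False)
        x = x + self.resid_attn_dropout(b)      # residual around attention
        n = self.ffn(self.norm_2(x))
        return x + self.resid_ffn_dropout(n)    # residual around the MLP


x = torch.randn(2, 16, 64)
y = PreNormBlock(d_model=64, n_heads=4)(x)     # same shape as x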
model/llava/model/mpt/configuration_mpt.py
CHANGED
@@ -1,13 +1,52 @@
@@ -80,39 +119,76 @@ class MPTConfig(PretrainedConfig):
Both hunks reformat the vendored MPT configuration class (module-level default dicts, wrapped signatures, double-quoted strings, multi-line error messages). New version (the unchanged middle of the file is elided by the diff view):

"""A HuggingFace-style model configuration."""
from typing import Dict, Optional, Union

from transformers import PretrainedConfig

attn_config_defaults: Dict = {
    "attn_type": "multihead_attention",
    "attn_pdrop": 0.0,
    "attn_impl": "triton",
    "qk_ln": False,
    "clip_qkv": None,
    "softmax_scale": None,
    "prefix_lm": False,
    "attn_uses_sequence_id": False,
    "alibi": False,
    "alibi_bias_max": 8,
}
init_config_defaults: Dict = {
    "name": "kaiming_normal_",
    "fan_mode": "fan_in",
    "init_nonlinearity": "relu",
}


class MPTConfig(PretrainedConfig):
    model_type = "mpt"

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        expansion_ratio: int = 4,
        max_seq_len: int = 2048,
        vocab_size: int = 50368,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        learned_pos_emb: bool = True,
        attn_config: Dict = attn_config_defaults,
        init_device: str = "cpu",
        logit_scale: Optional[Union[float, str]] = None,
        no_bias: bool = False,
        verbose: int = 0,
        embedding_fraction: float = 1.0,
        norm_type: str = "low_precision_layernorm",
        use_cache: bool = False,
        init_config: Dict = init_config_defaults,
        **kwargs,
    ):
        """The MPT configuration class.

        Args:
            ... (argument documentation unchanged; elided by the diff view)
        """
        # ... (unchanged attribute assignments elided by the diff view) ...
        self.norm_type = norm_type
        self.use_cache = use_cache
        self.init_config = init_config
        if "name" in kwargs:
            del kwargs["name"]
        if "loss_fn" in kwargs:
            del kwargs["loss_fn"]
        super().__init__(**kwargs)
        self._validate_config()

    def _set_config_defaults(self, config, config_defaults):
        for k, v in config_defaults.items():
            if k not in config:
                config[k] = v
        return config

    def _validate_config(self):
        self.attn_config = self._set_config_defaults(
            self.attn_config, attn_config_defaults
        )
        self.init_config = self._set_config_defaults(
            self.init_config, init_config_defaults
        )
        if self.d_model % self.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        if any(
            (
                prob < 0 or prob > 1
                for prob in [
                    self.attn_config["attn_pdrop"],
                    self.resid_pdrop,
                    self.emb_pdrop,
                ]
            )
        ):
            raise ValueError(
                "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1"
            )
        if self.attn_config["attn_impl"] not in ["torch", "flash", "triton"]:
            raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
        if self.attn_config["prefix_lm"] and self.attn_config["attn_impl"] not in [
            "torch",
            "triton",
        ]:
            raise NotImplementedError(
                "prefix_lm only implemented with torch and triton attention."
            )
        if self.attn_config["alibi"] and self.attn_config["attn_impl"] not in [
            "torch",
            "triton",
        ]:
            raise NotImplementedError(
                "alibi only implemented with torch and triton attention."
            )
        if self.attn_config["attn_uses_sequence_id"] and self.attn_config[
            "attn_impl"
        ] not in ["torch", "triton"]:
            raise NotImplementedError(
                "attn_uses_sequence_id only implemented with torch and triton attention."
            )
        if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
            raise ValueError(
                "model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!"
            )
        if isinstance(self.logit_scale, str) and self.logit_scale != "inv_sqrt_d_model":
            raise ValueError(
                f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
            )
        if self.init_config.get("name", None) is None:
            raise ValueError(
                f"self.init_config={self.init_config!r} 'name' needs to be set."
            )
        if not self.learned_pos_emb and (not self.attn_config["alibi"]):
            raise ValueError(
                f"Positional information must be provided to the model using either learned_pos_emb or alibi."
            )
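`_set_config_defaults` plus `_validate_config` follow a simple pattern: overlay the user-supplied dict on the module-level defaults, then fail fast on inconsistent combinations. The same idea in isolation, using plain dicts and a reduced set of keys for illustration:

from typing import Dict

# Illustrative defaults only; the real module-level dicts are shown above.
attn_defaults: Dict = {"attn_impl": "triton", "attn_pdrop": 0.0, "alibi": False}


def merge_defaults(config: Dict, defaults: Dict) -> Dict:
    # Same pattern as MPTConfig._set_config_defaults: user keys win,
    # missing keys fall back to the defaults.
    for k, v in defaults.items():
        if k not in config:
            config[k] = v
    return config


def validate(config: Dict) -> None:
    if config["attn_impl"] not in ("torch", "flash", "triton"):
        raise ValueError(f"Unknown attn_impl={config['attn_impl']}")
    if not 0.0 <= config["attn_pdrop"] <= 1.0:
        raise ValueError("attn_pdrop must be a probability in [0, 1]")


cfg = merge_defaults({"attn_impl": "torch"}, attn_defaults)
validate(cfg)  # passes; cfg also carries the default attn_pdrop and alibi keys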
model/llava/model/mpt/hf_prefixlm_converter.py
CHANGED
The hunks in this file reformat the vendored Prefix-LM converter (regrouped imports, double-quoted strings, expanded keyword-argument calls). Abridged new version, restricted to the structure visible in the diff:

import warnings
from types import MethodType
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from transformers.models.bloom.modeling_bloom import (
    BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel,
    CausalLMOutputWithCrossAttentions, CrossEntropyLoss)
from transformers.models.bloom.modeling_bloom import \
    _expand_mask as _expand_mask_bloom
from transformers.models.bloom.modeling_bloom import \
    _make_causal_mask as _make_causal_mask_bloom
from transformers.models.bloom.modeling_bloom import logging
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
from transformers.models.opt.modeling_opt import OPTForCausalLM
from transformers.models.opt.modeling_opt import \
    _expand_mask as _expand_mask_opt
from transformers.models.opt.modeling_opt import \
    _make_causal_mask as _make_causal_mask_opt

logger = logging.get_logger(__name__)
_SUPPORTED_GPT_MODELS = (
    GPT2LMHeadModel,
    GPTJForCausalLM,
    GPTNeoForCausalLM,
    GPTNeoXForCausalLM,
)
CAUSAL_GPT_TYPES = Union[
    GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM
]

_convert_gpt_causal_lm_to_prefix_lm: asserts the model is a supported decoder-only GPT variant, saves _original_forward / _original_generate, and installs a wrapped forward that ORs the per-example bidirectional_mask into each attention module's cached causal bias before delegating, then restores the lower-triangular bias:

        bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
        for attn_module in attn_modules:
            attn_module.bias.data = torch.logical_or(
                attn_module.bias.data, bidirectional
            )
        output = call_og_forward()
        for attn_module in attn_modules:
            attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
        return output

    setattr(model, "forward", MethodType(forward, model))
    setattr(model, "generate", MethodType(generate, model))
    setattr(model, "_prefix_lm_converted", True)
    return model

_convert_bloom_causal_lm_to_prefix_lm: rebinds _prepare_attn_mask, _build_alibi_tensor, the transformer forward, the LM-head forward, and prepare_inputs_for_generation on the BLOOM model. The bidirectional mask is ANDed into the causal mask before the padding mask is combined in:

        if bidirectional_mask is not None:
            assert attention_mask.shape == bidirectional_mask.shape
            expanded_bidirectional_mask = _expand_mask_bloom(
                bidirectional_mask, tgt_length=src_length
            )
            combined_attention_mask = torch.logical_and(
                combined_attention_mask, expanded_bidirectional_mask
            )
        expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
        combined_attention_mask = (
            expanded_attn_mask
            if combined_attention_mask is None
            else expanded_attn_mask | combined_attention_mask
        )
        return combined_attention_mask

_convert_opt_causal_lm_to_prefix_lm: stores the bidirectional_mask on model.model.decoder, overrides _prepare_decoder_attention_mask to take torch.maximum of the expanded bidirectional mask and the causal mask (or an all-zero mask while generate() temporarily sets bidirectional_mask = "g"), and wraps forward / generate in the same way as the GPT path.

The module closes with the public helpers:

_SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)
CAUSAL_LM_TYPES = Union[
    GPT2LMHeadModel,
    GPTJForCausalLM,
    GPTNeoForCausalLM,
    GPTNeoXForCausalLM,
    BloomForCausalLM,
    OPTForCausalLM,
]

convert_hf_causal_lm_to_prefix_lm dispatches to the converter matching the model class and otherwise raises:

        raise TypeError(
            f"Cannot convert model to Prefix LM. "
            + f"Model does not belong to set of supported HF models:"
            + f"\n{_SUPPORTED_HF_MODELS}"
        )

add_bidirectional_mask_if_missing infers the mask when a batch lacks one:

    if "bidirectional_mask" not in batch:
        if batch.get("mode", None) == "icl_task":
            batch["bidirectional_mask"] = batch["attention_mask"].clone()
            for i, continuation_indices in enumerate(batch["continuation_indices"]):
                batch["bidirectional_mask"][i, continuation_indices] = 0
        elif "labels" in batch and "attention_mask" in batch:
            batch["bidirectional_mask"] = torch.logical_and(
                torch.eq(batch["attention_mask"], 1), torch.eq(batch["labels"], -100)
            ).type_as(batch["attention_mask"])
        else:
            raise KeyError(
                "No bidirectional_mask in batch and not sure how to construct one."
            )
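All three converters implement the same trick: keep causal attention for generation, but make the prompt (prefix) region fully bidirectional, driven by a per-example `bidirectional_mask`. A toy illustration of the resulting attention pattern, using a hypothetical helper that is not part of the repo (True means the query position may attend to the key position):

import torch


def prefix_lm_attend_mask(seq_len: int, prefix_len: int) -> torch.Tensor:
    # Causal part: position i may attend to positions j <= i.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Bidirectional part: every position may attend to the prefix tokens,
    # which in particular lets prefix tokens see later prefix tokens.
    bidirectional = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    bidirectional[:, :prefix_len] = True
    # This OR is the same operation the GPT converter applies to each
    # attention module's cached causal bias.
    return causal | bidirectional


mask = prefix_lm_attend_mask(seq_len=6, prefix_len=3)
# Rows 0-2 see all of columns 0-2; rows 3-5 remain strictly causal beyond the prefix.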
model/llava/model/mpt/meta_init_context.py
CHANGED
@@ -1,9 +1,11 @@
@@ -30,11 +32,12 @@ def init_empty_weights(include_buffers: bool=False):
@@ -62,33 +65,47 @@ def init_on_device(device: torch.device, include_buffers: bool=False):
The hunks reformat the two context managers (spacing around default values, wrapped calls). New version (docstrings and unchanged middle lines elided by the diff view):

from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def init_empty_weights(include_buffers: bool = False):
    """Meta initialization context manager.

    A context manager under which models are initialized with all parameters
    ...
    </Tip>
    """
    with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
        yield f


@contextmanager
def init_on_device(device: torch.device, include_buffers: bool = False):
    """Device initialization context manager.

    A context manager under which models are initialized with all parameters
    ...
    """
    # ... (start of register_empty_parameter unchanged; elided by the diff view) ...
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            module._parameters[name] = param_cls(
                module._parameters[name].to(device), **kwargs
            )

    def register_empty_buffer(module, name, buffer):
        old_register_buffer(module, name, buffer)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    if include_buffers:
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ["empty", "zeros", "ones", "full"]
        }
    else:
        tensor_constructors_to_patch = {}

    def patch_tensor_constructor(fn):
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    try:
        nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            nn.Module.register_buffer = register_empty_buffer
        for torch_function_name in tensor_constructors_to_patch.keys():
            setattr(
                torch,
                torch_function_name,
                patch_tensor_constructor(getattr(torch, torch_function_name)),
            )
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            nn.Module.register_buffer = old_register_buffer
        for (
            torch_function_name,
            old_torch_function,
        ) in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)
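In practice `init_empty_weights` is used to build a large module skeleton without allocating real storage: parameters land on the `meta` device and are only materialized when a checkpoint is loaded. A small usage sketch (the import path is an assumption that the repo root is on `PYTHONPATH`):

import torch.nn as nn

from model.llava.model.mpt.meta_init_context import init_empty_weights

with init_empty_weights():
    skeleton = nn.Linear(4096, 4096)  # created under the patched register_parameter

print(skeleton.weight.is_meta)        # True: no real memory was allocated
print(tuple(skeleton.weight.shape))   # (4096, 4096): shapes are still tracked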
model/llava/model/mpt/modeling_mpt.py
CHANGED
@@ -5,68 +5,95 @@ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py

New version (the changes are formatting only: imports regrouped, long calls and signatures wrapped across lines, and previously truncated string literals written out in full; the model logic is unchanged). Key excerpts:

import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (PreTrainedModel, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)
from transformers.modeling_outputs import (BaseModelOutputWithPast,
                                           CausalLMOutputWithPast)

from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
from .attention import attn_bias_shape, build_attn_bias
from .blocks import MPTBlock
from .configuration_mpt import MPTConfig
from .hf_prefixlm_converter import (add_bidirectional_mask_if_missing,
                                    convert_hf_causal_lm_to_prefix_lm)
from .meta_init_context import init_empty_weights
from .norm import NORM_CLASS_REGISTRY
from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

from transformers.utils import logging

logger = logging.get_logger(__name__)


class MPTPreTrainedModel(PreTrainedModel):
    config_class = MPTConfig
    base_model_prefix = "model"


class MPTModel(MPTPreTrainedModel):
    def __init__(self, config: MPTConfig):
        ...

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        prefix_mask: Optional[torch.ByteTensor] = None,
        sequence_id: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        tok_emb: Optional[torch.FloatTensor] = None,
    ):
        ...
        return BaseModelOutputWithPast(
            last_hidden_state=x,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
        )


class MPTForCausalLM(MPTPreTrainedModel):
    def __init__(self, config: MPTConfig):
        ...

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        prefix_mask: Optional[torch.ByteTensor] = None,
        sequence_id: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        use_cache: Optional[bool] = None,
    ):
        ...
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
        ...
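As a quick orientation to how these pieces compose, here is a hedged sketch of building a tiny MPT causal LM and running one forward pass. The config values and the import paths are assumptions for illustration; `MPTConfig`'s exact constructor lives in `configuration_mpt.py` and may expect different fields.

```
import torch

from model.llava.model.mpt.configuration_mpt import MPTConfig
from model.llava.model.mpt.modeling_mpt import MPTForCausalLM

# Assumed keyword arguments for a toy config; real LISA/LLaVA checkpoints ship
# their own config.json, so this is only a shape-level illustration.
config = MPTConfig(d_model=256, n_heads=4, n_layers=2, max_seq_len=128, vocab_size=1024)
model = MPTForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 16))
out = model(input_ids=input_ids, labels=input_ids)  # CausalLMOutputWithPast
print(out.logits.shape, out.loss)
```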
model/llava/model/mpt/norm.py
CHANGED
@@ -1,28 +1,55 @@

New version (Black-style reformatting; same normalization logic):

import torch


def _cast_if_autocast_enabled(tensor):
    if torch.is_autocast_enabled():
        if tensor.device.type == "cuda":
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == "cpu":
            dtype = torch.get_autocast_cpu_dtype()
        else:
            raise NotImplementedError()
        return tensor.to(dtype=dtype)
    return tensor


class LPLayerNorm(torch.nn.LayerNorm):
    def __init__(
        self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None
    ):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            device=device,
            dtype=dtype,
        )

    def forward(self, x):
        module_device = x.device
        downcast_x = _cast_if_autocast_enabled(x)
        downcast_weight = (
            _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
        )
        downcast_bias = (
            _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
        )
        with torch.autocast(enabled=False, device_type=module_device.type):
            return torch.nn.functional.layer_norm(
                downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps
            )


def rms_norm(x, weight=None, eps=1e-05):
    output = x / torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    if weight is not None:
        return output * weight
    return output


class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
        super().__init__()
        self.eps = eps
        if weight:
            self.weight = torch.nn.Parameter(
                torch.ones(normalized_shape, dtype=dtype, device=device)
            )
        else:
            self.register_parameter("weight", None)

    def forward(self, x):
        return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)


class LPRMSNorm(RMSNorm):
    def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
        super().__init__(
            normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device
        )

    def forward(self, x):
        downcast_x = _cast_if_autocast_enabled(x)
        downcast_weight = (
            _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
        )
        with torch.autocast(enabled=False, device_type=x.device.type):
            return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)


NORM_CLASS_REGISTRY = {
    "layernorm": torch.nn.LayerNorm,
    "low_precision_layernorm": LPLayerNorm,
    "rmsnorm": RMSNorm,
    "low_precision_rmsnorm": LPRMSNorm,
}
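A small usage sketch of the registry above; the shapes, the import path, and the autocast dtype are arbitrary choices for illustration, and `MPTModel` performs the same lookup via `config.norm_type`:

```
import torch

from model.llava.model.mpt.norm import NORM_CLASS_REGISTRY

# Resolve the norm class by name, as MPTModel does with config.norm_type.
norm_cls = NORM_CLASS_REGISTRY["low_precision_layernorm"]
norm = norm_cls(512)

x = torch.randn(2, 16, 512)
if torch.cuda.is_available():
    norm, x = norm.cuda(), x.cuda()
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        y = norm(x)  # params/inputs are downcast; layer_norm itself runs outside autocast
else:
    y = norm(x)
print(y.shape, y.dtype)
```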
model/llava/model/mpt/param_init_fns.py
CHANGED
@@ -3,101 +3,139 @@ import warnings

New version (Black-style reformatting: signatures and warning strings wrapped across lines; initialization behavior unchanged). Key excerpts:

from collections.abc import Sequence
from functools import partial
from typing import Optional, Tuple, Union

import torch
from torch import nn

from .norm import NORM_CLASS_REGISTRY


def torch_default_param_init_fn_(module: nn.Module, verbose: int = 0, **kwargs):
    ...


def fused_init_helper_(module: nn.Module, init_fn_):
    ...


def generic_param_init_fn_(
    module: nn.Module,
    init_fn_,
    n_layers: int,
    d_model: Optional[int] = None,
    init_div_is_residual: Union[int, float, str, bool] = True,
    emb_init_std: Optional[float] = None,
    emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
    verbose: int = 0,
    **kwargs,
):
    ...


# The remaining initializers (_normal_param_init_fn_, baseline_param_init_fn_,
# small_param_init_fn_, neox_param_init_fn_, kaiming_uniform_/kaiming_normal_,
# xavier_uniform_/xavier_normal_) keep the same parameters and delegate,
# directly or indirectly, to generic_param_init_fn_.

MODEL_INIT_REGISTRY = {
    "default_": torch_default_param_init_fn_,
    "baseline_": baseline_param_init_fn_,
    "kaiming_uniform_": kaiming_uniform_param_init_fn_,
    "kaiming_normal_": kaiming_normal_param_init_fn_,
    "neox_init_": neox_param_init_fn_,
    "small_init_": small_param_init_fn_,
    "xavier_uniform_": xavier_uniform_param_init_fn_,
    "xavier_normal_": xavier_normal_param_init_fn_,
}
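To make the registry concrete, a hedged sketch of applying one of these initializers outside of MPTModel; the module shapes and import path are placeholders, and inside the model this is driven by `config.init_config["name"]`:

```
from functools import partial

import torch.nn as nn

from model.llava.model.mpt.param_init_fns import MODEL_INIT_REGISTRY

# Mirror MPTModel.param_init_fn: pick an init function by name and apply it to
# every submodule; "small_init_" requires n_layers and d_model.
init_fn = partial(MODEL_INIT_REGISTRY["small_init_"], n_layers=2, d_model=256)

model = nn.Sequential(nn.Linear(256, 1024), nn.GELU(), nn.Linear(1024, 256))
model.apply(lambda m: init_fn(module=m))
```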
model/llava/model/utils.py
CHANGED
@@ -5,16 +5,20 @@ from transformers import AutoConfig, StoppingCriteria
 
 def auto_upgrade(config):
     cfg = AutoConfig.from_pretrained(config)
-    if
-        assert cfg.model_type ==
-        print(
+    if "llava" in config and "llava" not in cfg.model_type:
+        assert cfg.model_type == "llama"
+        print(
+            "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base."
+        )
+        print(
+            "You must upgrade the checkpoint to the new code base (this can be done automatically)."
+        )
         confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
         if confirm.lower() in ["y", "yes"]:
             print("Upgrading checkpoint...")
             assert len(cfg.architectures) == 1
             setattr(cfg.__class__, "model_type", "llava")
-            cfg.architectures[0] =
+            cfg.architectures[0] = "LlavaLlamaForCausalLM"
             cfg.save_pretrained(config)
             print("Checkpoint upgraded.")
         else:
@@ -22,24 +26,31 @@ def auto_upgrade(config):
             exit(1)
 
 
 class KeywordsStoppingCriteria(StoppingCriteria):
     def __init__(self, keywords, tokenizer, input_ids):
         self.keywords = keywords
         self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
-        self.keyword_ids = [
+        self.keyword_ids = [
+            keyword_id[0]
+            for keyword_id in self.keyword_ids
+            if type(keyword_id) is list and len(keyword_id) == 1
+        ]
         self.tokenizer = tokenizer
         self.start_len = None
         self.input_ids = input_ids
 
-    def __call__(
+    def __call__(
+        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+    ) -> bool:
         if self.start_len is None:
             self.start_len = self.input_ids.shape[1]
         else:
             for keyword_id in self.keyword_ids:
                 if output_ids[0, -1] == keyword_id:
                     return True
-            outputs = self.tokenizer.batch_decode(
+            outputs = self.tokenizer.batch_decode(
+                output_ids[:, self.start_len :], skip_special_tokens=True
+            )[0]
             for keyword in self.keywords:
                 if keyword in outputs:
                     return True
@@ -50,7 +61,7 @@ class KeywordsStoppingCriteria(StoppingCriteria):
 
         # # if output_ids[0, -1] == keyword_id:
         # #     return True
-
+
         # print("output_ids.shape: {}, self.start_len: {}".format(output_ids.shape, self.start_len))
 
         # print("output_ids[:, self.start_len:]: ", output_ids[:, self.start_len:])
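The reworked `KeywordsStoppingCriteria` plugs into the standard Hugging Face `StoppingCriteriaList`. A hedged usage sketch follows; the model path and the `"###"` separator keyword are placeholders, not values taken from this commit.

```python
# Sketch: stop generation once a separator keyword shows up in the decoded text.
from llava.model.utils import KeywordsStoppingCriteria  # class shown above
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("path/to/model")     # placeholder path
model = AutoModelForCausalLM.from_pretrained("path/to/model")  # placeholder path

prompt = "### Human: Describe the image.\n### Assistant:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

stopping = KeywordsStoppingCriteria(["###"], tokenizer, input_ids)
output_ids = model.generate(
    input_ids,
    max_new_tokens=128,
    stopping_criteria=StoppingCriteriaList([stopping]),
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```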
model/llava/serve/cli.py
CHANGED
@@ -6,14 +6,14 @@ import argparse
 import time
 
 import torch
-from llava.conversation import conv_templates, SeparatorStyle
+from llava.conversation import SeparatorStyle, conv_templates
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 @torch.inference_mode()
-def generate_stream(
+def generate_stream(
+    tokenizer, model, params, device, context_len=2048, stream_interval=2
+):
     """Adapted from fastchat/serve/model_worker.py::generate_stream"""
 
     prompt = params["prompt"]
@@ -30,17 +30,19 @@ def generate_stream(tokenizer, model, params, device,
 
     for i in range(max_new_tokens):
         if i == 0:
-            out = model(
-                torch.as_tensor([input_ids], device=device), use_cache=True)
+            out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
             logits = out.logits
             past_key_values = out.past_key_values
         else:
             attention_mask = torch.ones(
-                1, past_key_values[0][0].shape[-2] + 1, device=device
+                1, past_key_values[0][0].shape[-2] + 1, device=device
+            )
+            out = model(
+                input_ids=torch.as_tensor([[token]], device=device),
+                use_cache=True,
+                attention_mask=attention_mask,
+                past_key_values=past_key_values,
+            )
             logits = out.logits
             past_key_values = out.past_key_values
 
@@ -84,18 +86,21 @@ def main(args):
         else:
             num_gpus = int(num_gpus)
             if num_gpus != 1:
-                kwargs.update(
+                kwargs.update(
+                    {
+                        "device_map": "auto",
+                        "max_memory": {i: "13GiB" for i in range(num_gpus)},
+                    }
+                )
     elif args.device == "cpu":
         kwargs = {}
     else:
        raise ValueError(f"Invalid device: {args.device}")
 
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(
-        low_cpu_mem_usage=True, **kwargs
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name, low_cpu_mem_usage=True, **kwargs
+    )
 
     if args.device == "cuda" and num_gpus == 1:
         model.cuda()
@@ -126,11 +131,11 @@ def main(args):
         print(f"{conv.roles[1]}: ", end="", flush=True)
         pre = 0
         for outputs in generate_stream(tokenizer, model, params, args.device):
-            outputs = outputs[len(prompt) + 1:].strip()
+            outputs = outputs[len(prompt) + 1 :].strip()
             outputs = outputs.split(" ")
             now = len(outputs)
             if now - 1 > pre:
-                print(" ".join(outputs[pre:now-1]), end=" ", flush=True)
+                print(" ".join(outputs[pre : now - 1]), end=" ", flush=True)
                 pre = now - 1
         print(" ".join(outputs[pre:]), flush=True)
 
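The rewritten loop in `generate_stream` is the usual KV-cache pattern: the first step feeds the whole prompt, every later step feeds only the last sampled token together with the cached keys/values. A stripped-down sketch of the same pattern (greedy decoding, no stop handling, model assumed to already sit on the target device):

```python
# Hedged sketch of incremental decoding with a KV cache, mirroring generate_stream.
import torch


@torch.inference_mode()
def greedy_stream(model, input_ids, max_new_tokens=32, device="cuda"):
    past_key_values = None
    token = None
    for i in range(max_new_tokens):
        if i == 0:
            # First step: run the full prompt and populate the cache.
            out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
        else:
            # Later steps: feed only the last token plus the cached keys/values.
            out = model(
                input_ids=torch.as_tensor([[token]], device=device),
                use_cache=True,
                past_key_values=past_key_values,
            )
        past_key_values = out.past_key_values
        token = int(torch.argmax(out.logits[0, -1]))  # greedy next-token pick
        yield token
```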
model/llava/serve/controller.py
CHANGED
@@ -5,23 +5,21 @@ It sends worker addresses to clients.
 import argparse
 import asyncio
 import dataclasses
-from enum import Enum, auto
 import json
 import logging
+import threading
 import time
+from enum import Enum, auto
 from typing import List, Union
-import threading
 
-from fastapi import FastAPI, Request
-from fastapi.responses import StreamingResponse
 import numpy as np
 import requests
 import uvicorn
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
 from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
 from llava.utils import build_logger, server_error_msg
 
-
 logger = build_logger("controller", "controller.log")
 
 
@@ -61,13 +59,15 @@ class Controller:
         self.dispatch_method = DispatchMethod.from_str(dispatch_method)
 
         self.heart_beat_thread = threading.Thread(
-            target=heart_beat_controller, args=(self,)
+            target=heart_beat_controller, args=(self,)
+        )
         self.heart_beat_thread.start()
 
         logger.info("Init controller")
 
-    def register_worker(
+    def register_worker(
+        self, worker_name: str, check_heart_beat: bool, worker_status: dict
+    ):
         if worker_name not in self.worker_info:
             logger.info(f"Register a new worker: {worker_name}")
         else:
@@ -79,8 +79,12 @@ class Controller:
             return False
 
         self.worker_info[worker_name] = WorkerInfo(
-            worker_status["model_names"],
+            worker_status["model_names"],
+            worker_status["speed"],
+            worker_status["queue_length"],
+            check_heart_beat,
+            time.time(),
+        )
 
         logger.info(f"Register done: {worker_name}, {worker_status}")
         return True
@@ -131,15 +135,13 @@ class Controller:
             return ""
         worker_speeds = worker_speeds / norm
         if True:  # Directly return address
-            pt = np.random.choice(np.arange(len(worker_names)),
-                                  p=worker_speeds)
+            pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
             worker_name = worker_names[pt]
             return worker_name
 
         # Check status before returning
         while True:
-            pt = np.random.choice(np.arange(len(worker_names)),
-                                  p=worker_speeds)
+            pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
             worker_name = worker_names[pt]
 
             if self.get_worker_status(worker_name):
@@ -165,7 +167,9 @@ class Controller:
             min_index = np.argmin(worker_qlen)
             w_name = worker_names[min_index]
             self.worker_info[w_name].queue_length += 1
-            logger.info(
+            logger.info(
+                f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}"
+            )
             return w_name
         else:
             raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
@@ -201,8 +205,12 @@ class Controller:
             yield json.dumps(ret).encode() + b"\0"
 
         try:
-            response = requests.post(
+            response = requests.post(
+                worker_addr + "/worker_generate_stream",
+                json=params,
+                stream=True,
+                timeout=5,
+            )
             for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                 if chunk:
                     yield chunk + b"\0"
@@ -214,7 +222,6 @@ class Controller:
             }
             yield json.dumps(ret).encode() + b"\0"
 
-
     # Let the controller act as a worker to achieve hierarchical
     # management. This can be used to connect isolated sub networks.
     def worker_api_get_status(self):
@@ -243,8 +250,8 @@ app = FastAPI()
 async def register_worker(request: Request):
     data = await request.json()
     controller.register_worker(
-        data["worker_name"], data["check_heart_beat"],
+        data["worker_name"], data["check_heart_beat"], data.get("worker_status", None)
+    )
 
 
 @app.post("/refresh_all_workers")
@@ -268,8 +275,7 @@ async def get_worker_address(request: Request):
 @app.post("/receive_heart_beat")
 async def receive_heart_beat(request: Request):
     data = await request.json()
-    exist = controller.receive_heart_beat(
-        data["worker_name"], data["queue_length"])
+    exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"])
     return {"exist": exist}
 
 
@@ -289,8 +295,12 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21001)
-    parser.add_argument(
-        "
+    parser.add_argument(
+        "--dispatch-method",
+        type=str,
+        choices=["lottery", "shortest_queue"],
+        default="shortest_queue",
+    )
    args = parser.parse_args()
    logger.info(f"args: {args}")
 
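Registration and heartbeats between worker and controller are plain JSON over HTTP. For reference, a hedged sketch of the payload a worker sends when it registers; the field names follow `register_worker` above, while the addresses are the defaults and the model name is only illustrative.

```python
# Hedged sketch: what a worker POSTs to the controller's /register_worker endpoint.
import requests

payload = {
    "worker_name": "http://localhost:21002",    # default worker address
    "check_heart_beat": True,
    "worker_status": {
        "model_names": ["LISA-13B-llama2-v0"],  # illustrative model name
        "speed": 1,
        "queue_length": 0,
    },
}
requests.post("http://localhost:21001/register_worker", json=payload, timeout=5)
```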
model/llava/serve/gradio_css.py
CHANGED
@@ -1,5 +1,4 @@
-code_highlight_css = (
-    """
+code_highlight_css = """
 #chatbot .hll { background-color: #ffffcc }
 #chatbot .c { color: #408080; font-style: italic }
 #chatbot .err { border: 1px solid #FF0000 }
@@ -68,6 +67,5 @@ code_highlight_css = (
 #chatbot .vi { color: #19177C }
 #chatbot .vm { color: #19177C }
 #chatbot .il { color: #666666 }
-"""
+"""
+# .highlight { background: #f8f8f8; }
model/llava/serve/gradio_patch.py
CHANGED
@@ -50,7 +50,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
         warnings.warn(
             "The 'color_map' parameter has been deprecated.",
         )
-        #self.md = utils.get_markdown_parser()
+        # self.md = utils.get_markdown_parser()
         self.md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"])
         self.select: EventListenerMethod
         """
@@ -113,7 +113,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
         ):  # This happens for previously processed messages
             return chat_message
         elif isinstance(chat_message, str):
-            #return self.md.render(chat_message)
+            # return self.md.render(chat_message)
             return str(self.md.convert(chat_message))
         else:
             raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
@@ -142,9 +142,10 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
             ), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
             processed_messages.append(
                 (
-                    #self._process_chat_messages(message_pair[0]),
-                    '<pre style="font-family: var(--font)">'
-                    message_pair[0]
+                    # self._process_chat_messages(message_pair[0]),
+                    '<pre style="font-family: var(--font)">'
+                    + message_pair[0]
+                    + "</pre>",
                     self._process_chat_messages(message_pair[1]),
                 )
             )
@@ -164,5 +165,3 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
             **kwargs,
         )
         return self
-
-
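Net effect of the patch: the user turn is wrapped in a `<pre>` tag so it renders verbatim, while only the model turn goes through markdown2. A small hedged sketch of that post-processing, with `md` mirroring `self.md` in the class above:

```python
# Hedged sketch of the per-pair rendering the patched Chatbot applies.
from markdown2 import Markdown

md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"])


def render_pair(user_text: str, bot_text: str):
    # User text is shown verbatim inside <pre>; bot text is markdown-rendered.
    return (
        '<pre style="font-family: var(--font)">' + user_text + "</pre>",
        str(md.convert(bot_text)),
    )
```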
model/llava/serve/gradio_web_server.py
CHANGED
@@ -1,22 +1,20 @@
 import argparse
-from collections import defaultdict
 import datetime
+import hashlib
 import json
 import os
 import time
+from collections import defaultdict
 
 import gradio as gr
 import requests
-
-from llava.conversation import (default_conversation, conv_templates,
-                                SeparatorStyle)
 from llava.constants import LOGDIR
-from llava.
-from llava.serve.gradio_patch import Chatbot as grChatbot
+from llava.conversation import (SeparatorStyle, conv_templates,
+                                default_conversation)
 from llava.serve.gradio_css import code_highlight_css
+from llava.serve.gradio_patch import Chatbot as grChatbot
+from llava.utils import (build_logger, moderation_msg, server_error_msg,
+                         violates_moderation)
 
 logger = build_logger("gradio_web_server", "gradio_web_server.log")
 
@@ -65,31 +63,33 @@ def load_demo(url_params, request: gr.Request):
     if "model" in url_params:
         model = url_params["model"]
         if model in models:
-            dropdown_update = gr.Dropdown.update(
-                value=model, visible=True)
+            dropdown_update = gr.Dropdown.update(value=model, visible=True)
 
     state = default_conversation.copy()
-    return (
+    return (
+        state,
+        dropdown_update,
+        gr.Chatbot.update(visible=True),
+        gr.Textbox.update(visible=True),
+        gr.Button.update(visible=True),
+        gr.Row.update(visible=True),
+        gr.Accordion.update(visible=True),
+    )
 
 
 def load_demo_refresh_model_list(request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}")
     models = get_model_list()
     state = default_conversation.copy()
-    return (
+    return (
+        state,
+        gr.Dropdown.update(choices=models, value=models[0] if len(models) > 0 else ""),
+        gr.Chatbot.update(visible=True),
+        gr.Textbox.update(visible=True),
+        gr.Button.update(visible=True),
+        gr.Row.update(visible=True),
+        gr.Accordion.update(visible=True),
+    )
 
 
 def vote_last_response(state, vote_type, model_selector, request: gr.Request):
@@ -148,13 +148,14 @@ def add_text(state, text, image, image_process_mode, request: gr.Request):
     if flagged:
         state.skip_next = True
         return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
-            no_change_btn,
+            no_change_btn,
+        ) * 5
 
     text = text[:1536]  # Hard cut-off
     if image is not None:
         text = text[:1200]  # Hard cut-off for images
-        if
-            text = text +
+        if "<image>" not in text:
+            text = text + "\n<image>"
         text = (text, image, image_process_mode)
         state = default_conversation.copy()
     state.append_message(state.roles[0], text)
@@ -195,9 +196,9 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
         template_name = "multimodal"
     elif "mpt" in model_name:
         template_name = "mpt_text"
-    elif "koala" in model_name:
+    elif "koala" in model_name:  # Hardcode the condition
         template_name = "bair_v1"
-    elif "v1" in model_name:
+    elif "v1" in model_name:  # vicuna v1_1/v1_2
         template_name = "vicuna_v1_1"
     else:
         template_name = "v1"
@@ -208,15 +209,24 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
 
     # Query worker address
     controller_url = args.controller_url
-    ret = requests.post(
+    ret = requests.post(
+        controller_url + "/get_worker_address", json={"model": model_name}
+    )
     worker_addr = ret.json()["address"]
     logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
 
     # No available worker
     if worker_addr == "":
         state.messages[-1][-1] = server_error_msg
-        yield (
+        yield (
+            state,
+            state.to_gradio_chatbot(),
+            disable_btn,
+            disable_btn,
+            disable_btn,
+            enable_btn,
+            enable_btn,
+        )
         return
 
     # Construct prompt
@@ -226,7 +236,9 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
     all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
     for image, hash in zip(all_images, all_image_hash):
         t = datetime.datetime.now()
-        filename = os.path.join(
+        filename = os.path.join(
+            LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg"
+        )
         if not os.path.isfile(filename):
             os.makedirs(os.path.dirname(filename), exist_ok=True)
             image.save(filename)
@@ -237,37 +249,56 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
         "prompt": prompt,
         "temperature": float(temperature),
         "max_new_tokens": min(int(max_new_tokens), 1536),
-        "stop": state.sep
+        "stop": state.sep
+        if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
+        else state.sep2,
+        "images": f"List of {len(state.get_images())} images: {all_image_hash}",
     }
     logger.info(f"==== request ====\n{pload}")
 
-    pload[
+    pload["images"] = state.get_images()
 
     state.messages[-1][-1] = "▌"
     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
 
     try:
         # Stream output
-        response = requests.post(
+        response = requests.post(
+            worker_addr + "/worker_generate_stream",
+            headers=headers,
+            json=pload,
+            stream=True,
+            timeout=10,
+        )
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
                 data = json.loads(chunk.decode())
                 if data["error_code"] == 0:
-                    output = data["text"][len(prompt):].strip()
+                    output = data["text"][len(prompt) :].strip()
                    output = post_process_code(output)
                     state.messages[-1][-1] = output + "▌"
                     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                 else:
                     output = data["text"] + f" (error_code: {data['error_code']})"
                     state.messages[-1][-1] = output
-                    yield (state, state.to_gradio_chatbot()) + (
+                    yield (state, state.to_gradio_chatbot()) + (
+                        disable_btn,
+                        disable_btn,
+                        disable_btn,
+                        enable_btn,
+                        enable_btn,
+                    )
                     return
                 time.sleep(0.03)
     except requests.exceptions.RequestException as e:
         state.messages[-1][-1] = server_error_msg
-        yield (state, state.to_gradio_chatbot()) + (
+        yield (state, state.to_gradio_chatbot()) + (
+            disable_btn,
+            disable_btn,
+            disable_btn,
+            enable_btn,
+            enable_btn,
+        )
         return
 
     state.messages[-1][-1] = state.messages[-1][-1][:-1]
@@ -289,27 +320,30 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
         }
         fout.write(json.dumps(data) + "\n")
 
+
+title_markdown = """
 # 🌋 LLaVA: Large Language and Vision Assistant
 [[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485) [[Code]](https://github.com/haotian-liu/LLaVA) [[Model]](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v0)
-"""
+"""
 
-tos_markdown =
+tos_markdown = """
 ### Terms of use
 By using this service, users are required to agree to the following terms:
 The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
 Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
 For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
-"""
+"""
 
 
-learn_more_markdown =
+learn_more_markdown = """
 ### License
 The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
-"""
+"""
 
 
-css =
+css = (
+    code_highlight_css
+    + """
 pre {
 white-space: pre-wrap; /* Since CSS 2.1 */
 white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
@@ -318,11 +352,13 @@ pre {
 word-wrap: break-word; /* Internet Explorer 5.5+ */
 }
 """
+)
 
 
 def build_demo(embed_mode):
-    textbox = gr.Textbox(
-        placeholder="Enter text and press ENTER", visible=False
+    textbox = gr.Textbox(
+        show_label=False, placeholder="Enter text and press ENTER", visible=False
+    ).style(container=False)
     with gr.Blocks(title="LLaVA", theme=gr.themes.Base(), css=css) as demo:
         state = gr.State()
 
@@ -336,26 +372,55 @@ def build_demo(embed_mode):
                     choices=models,
                     value=models[0] if len(models) > 0 else "",
                     interactive=True,
-                    show_label=False
+                    show_label=False,
+                ).style(container=False)
 
                 imagebox = gr.Image(type="pil")
                 image_process_mode = gr.Radio(
                     ["Crop", "Resize", "Pad"],
                     value="Crop",
-                    label="Preprocess for non-square image"
+                    label="Preprocess for non-square image",
+                )
 
                 cur_dir = os.path.dirname(os.path.abspath(__file__))
-                gr.Examples(
-                    [
+                gr.Examples(
+                    examples=[
+                        [
+                            f"{cur_dir}/examples/extreme_ironing.jpg",
+                            "What is unusual about this image?",
+                        ],
+                        [
+                            f"{cur_dir}/examples/waterview.jpg",
+                            "What are the things I should be cautious about when I visit here?",
+                        ],
+                    ],
+                    inputs=[imagebox, textbox],
+                )
+
+                with gr.Accordion(
+                    "Parameters", open=False, visible=False
+                ) as parameter_row:
+                    temperature = gr.Slider(
+                        minimum=0.0,
+                        maximum=1.0,
+                        value=0.2,
+                        step=0.1,
+                        interactive=True,
+                        label="Temperature",
+                    )
+                    max_output_tokens = gr.Slider(
+                        minimum=0,
+                        maximum=1024,
+                        value=512,
+                        step=64,
+                        interactive=True,
+                        label="Max output tokens",
+                    )
 
             with gr.Column(scale=6):
-                chatbot = grChatbot(
+                chatbot = grChatbot(
+                    elem_id="chatbot", label="LLaVA Chatbot", visible=False
+                ).style(height=550)
                 with gr.Row():
                     with gr.Column(scale=8):
                         textbox.render()
@@ -365,7 +430,7 @@ def build_demo(embed_mode):
                     upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                     downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                     flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
-                    #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
+                    # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                     regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                     clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
 
@@ -376,32 +441,82 @@ def build_demo(embed_mode):
 
         # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
-        upvote_btn.click(
+        upvote_btn.click(
+            upvote_last_response,
+            [state, model_selector],
+            [textbox, upvote_btn, downvote_btn, flag_btn],
+        )
+        downvote_btn.click(
+            downvote_last_response,
+            [state, model_selector],
+            [textbox, upvote_btn, downvote_btn, flag_btn],
+        )
+        flag_btn.click(
+            flag_last_response,
+            [state, model_selector],
+            [textbox, upvote_btn, downvote_btn, flag_btn],
+        )
+        regenerate_btn.click(
+            regenerate,
+            [state, image_process_mode],
+            [state, chatbot, textbox, imagebox] + btn_list,
+        ).then(
+            http_bot,
+            [state, model_selector, temperature, max_output_tokens],
+            [state, chatbot] + btn_list,
+        )
+        clear_btn.click(
+            clear_history, None, [state, chatbot, textbox, imagebox] + btn_list
+        )
+
+        textbox.submit(
+            add_text,
+            [state, textbox, imagebox, image_process_mode],
+            [state, chatbot, textbox, imagebox] + btn_list,
+        ).then(
+            http_bot,
+            [state, model_selector, temperature, max_output_tokens],
+            [state, chatbot] + btn_list,
+        )
+        submit_btn.click(
+            add_text,
+            [state, textbox, imagebox, image_process_mode],
+            [state, chatbot, textbox, imagebox] + btn_list,
+        ).then(
+            http_bot,
+            [state, model_selector, temperature, max_output_tokens],
+            [state, chatbot] + btn_list,
+        )
 
         if args.model_list_mode == "once":
-            demo.load(
+            demo.load(
+                load_demo,
+                [url_params],
+                [
+                    state,
+                    model_selector,
+                    chatbot,
+                    textbox,
+                    submit_btn,
+                    button_row,
+                    parameter_row,
+                ],
+                _js=get_window_url_params,
+            )
        elif args.model_list_mode == "reload":
-            demo.load(
+            demo.load(
+                load_demo_refresh_model_list,
+                None,
+                [
+                    state,
+                    model_selector,
+                    chatbot,
+                    textbox,
+                    submit_btn,
+                    button_row,
+                    parameter_row,
+                ],
+            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
 
@@ -414,8 +529,9 @@ if __name__ == "__main__":
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
    parser.add_argument("--concurrency-count", type=int, default=8)
-    parser.add_argument(
-        choices=["once", "reload"]
+    parser.add_argument(
+        "--model-list-mode", type=str, default="once", choices=["once", "reload"]
+    )
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
@@ -426,6 +542,6 @@ if __name__ == "__main__":
 
    logger.info(args)
    demo = build_demo(args.embed)
-    demo.queue(
+    demo.queue(
+        concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
+    ).launch(server_name=args.host, server_port=args.port, share=args.share)
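The streaming protocol between `http_bot` and the worker is a sequence of JSON objects, each terminated by a NUL byte. A hedged client-side sketch of that loop (worker address and payload are placeholders):

```python
# Sketch: consume the worker's NUL-delimited JSON stream, as http_bot does above.
import json

import requests


def stream_answer(worker_addr: str, pload: dict):
    response = requests.post(
        worker_addr + "/worker_generate_stream", json=pload, stream=True, timeout=10
    )
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        if data["error_code"] != 0:
            raise RuntimeError(f"worker error_code={data['error_code']}: {data['text']}")
        yield data["text"]  # cumulative text so far, prompt included
```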
model/llava/serve/model_worker.py
CHANGED
@@ -4,25 +4,23 @@ A model worker executes the model.
|
|
4 |
import argparse
|
5 |
import asyncio
|
6 |
import dataclasses
|
7 |
-
import logging
|
8 |
import json
|
9 |
-
import
|
10 |
-
from typing import List, Union
|
11 |
import threading
|
|
|
12 |
import uuid
|
|
|
|
|
13 |
|
14 |
-
from fastapi import FastAPI, Request, BackgroundTasks
|
15 |
-
from fastapi.responses import StreamingResponse
|
16 |
import requests
|
17 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
18 |
import torch
|
19 |
import uvicorn
|
20 |
-
from
|
21 |
-
|
22 |
from llava.constants import WORKER_HEART_BEAT_INTERVAL
|
23 |
-
from llava.utils import (build_logger, server_error_msg,
|
24 |
-
pretty_print_semaphore)
|
25 |
from llava.model import *
|
|
|
|
|
26 |
|
27 |
GB = 1 << 30
|
28 |
|
@@ -40,7 +38,6 @@ DEFAULT_IM_END_TOKEN = "<im_end>"
|
|
40 |
|
41 |
|
42 |
def heart_beat_worker(controller):
|
43 |
-
|
44 |
while True:
|
45 |
time.sleep(WORKER_HEART_BEAT_INTERVAL)
|
46 |
controller.send_heart_beat()
|
@@ -56,38 +53,66 @@ def load_model(model_path, model_name, num_gpus):
|
|
56 |
}
|
57 |
|
58 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
59 |
-
if
|
60 |
-
if
|
61 |
-
model = LlavaMPTForCausalLM.from_pretrained(
|
|
|
|
|
62 |
else:
|
63 |
-
model = LlavaLlamaForCausalLM.from_pretrained(
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
else:
|
67 |
-
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
|
68 |
|
69 |
image_processor = None
|
70 |
|
71 |
-
if
|
72 |
from transformers import CLIPImageProcessor, CLIPVisionModel
|
73 |
-
|
|
|
|
|
|
|
74 |
|
75 |
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
76 |
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
77 |
if mm_use_im_start_end:
|
78 |
-
tokenizer.add_tokens(
|
|
|
|
|
79 |
|
80 |
vision_tower = model.get_model().vision_tower[0]
|
81 |
-
if vision_tower.device.type ==
|
82 |
-
vision_tower = CLIPVisionModel.from_pretrained(
|
|
|
|
|
|
|
|
|
83 |
model.get_model().vision_tower[0] = vision_tower
|
84 |
else:
|
85 |
-
vision_tower.to(device=
|
86 |
vision_config = vision_tower.config
|
87 |
-
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
|
|
|
|
|
88 |
vision_config.use_im_start_end = mm_use_im_start_end
|
89 |
if mm_use_im_start_end:
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
if num_gpus == 1:
|
93 |
model.cuda()
|
@@ -101,11 +126,17 @@ def load_model(model_path, model_name, num_gpus):
|
|
101 |
|
102 |
|
103 |
class ModelWorker:
|
104 |
-
def __init__(
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
self.controller_addr = controller_addr
|
110 |
self.worker_addr = worker_addr
|
111 |
self.worker_id = worker_id
|
@@ -113,7 +144,7 @@ class ModelWorker:
|
|
113 |
model_path = model_path[:-1]
|
114 |
if model_name is None:
|
115 |
model_paths = model_path.split("/")
|
116 |
-
if model_paths[-1].startswith(
|
117 |
self.model_name = model_paths[-2] + "_" + model_paths[-1]
|
118 |
else:
|
119 |
self.model_name = model_paths[-1]
|
@@ -123,13 +154,15 @@ class ModelWorker:
|
|
123 |
logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
|
124 |
self.keep_aspect_ratio = keep_aspect_ratio
|
125 |
self.tokenizer, self.model, self.image_processor, self.context_len = load_model(
|
126 |
-
model_path, self.model_name, num_gpus
|
127 |
-
|
|
|
128 |
|
129 |
if not no_register:
|
130 |
self.register_to_controller()
|
131 |
self.heart_beat_thread = threading.Thread(
|
132 |
-
target=heart_beat_worker, args=(self,)
|
|
|
133 |
self.heart_beat_thread.start()
|
134 |
|
135 |
def register_to_controller(self):
|
@@ -139,23 +172,30 @@ class ModelWorker:
|
|
139 |
data = {
|
140 |
"worker_name": self.worker_addr,
|
141 |
"check_heart_beat": True,
|
142 |
-
"worker_status": self.get_status()
|
143 |
}
|
144 |
r = requests.post(url, json=data)
|
145 |
assert r.status_code == 200
|
146 |
|
147 |
def send_heart_beat(self):
|
148 |
-
logger.info(
|
149 |
-
|
150 |
-
|
|
|
|
|
151 |
|
152 |
url = self.controller_addr + "/receive_heart_beat"
|
153 |
|
154 |
while True:
|
155 |
try:
|
156 |
-
ret = requests.post(
|
157 |
-
|
158 |
-
|
|
|
|
|
|
|
|
|
|
|
159 |
exist = ret.json()["exist"]
|
160 |
break
|
161 |
except requests.exceptions.RequestException as e:
|
@@ -169,8 +209,15 @@ class ModelWorker:
|
|
169 |
if model_semaphore is None:
|
170 |
return 0
|
171 |
else:
|
172 |
-
return
|
173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
|
175 |
def get_status(self):
|
176 |
return {
|
@@ -181,20 +228,30 @@ class ModelWorker:
|
|
181 |
|
182 |
@torch.inference_mode()
|
183 |
def generate_stream(self, params):
|
184 |
-
tokenizer, model, image_processor =
|
|
|
|
|
|
|
|
|
185 |
|
186 |
prompt = params["prompt"]
|
187 |
ori_prompt = prompt
|
188 |
images = params.get("images", None)
|
189 |
if images is not None and len(images) > 0 and self.is_multimodal:
|
190 |
-
from PIL import Image
|
191 |
-
from io import BytesIO
|
192 |
import base64
|
|
|
|
|
|
|
|
|
193 |
assert type(images) is list
|
194 |
if len(images) > 0:
|
195 |
# assert len(images) == 1, "Only support one image for now"
|
196 |
-
images = [
|
197 |
-
|
|
|
|
|
|
|
|
|
198 |
|
199 |
if self.keep_aspect_ratio:
|
200 |
new_images = []
|
@@ -203,21 +260,40 @@ class ModelWorker:
|
|
203 |
aspect_ratio = max_hw / min_hw
|
204 |
max_len, min_len = 448, 224
|
205 |
shortest_edge = int(min(max_len / aspect_ratio, min_len))
|
206 |
-
image = image_processor.preprocess(
|
207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
# replace the image token with the image patch token in the prompt (each occurrence)
|
209 |
-
cur_token_len = (image.shape[1]//14) * (image.shape[2]//14)
|
210 |
replace_token = DEFAULT_IMAGE_PATCH_TOKEN * cur_token_len
|
211 |
-
if getattr(self.model.config,
|
212 |
-
replace_token =
|
|
|
|
|
|
|
|
|
213 |
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token, 1)
|
214 |
images = new_images
|
215 |
else:
|
216 |
-
images = image_processor(images, return_tensors=
|
|
|
|
|
217 |
images = images.to(self.model.device, dtype=torch.float16)
|
218 |
-
replace_token =
|
219 |
-
|
220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
222 |
else:
|
223 |
images = None
|
@@ -249,18 +325,20 @@ class ModelWorker:
|
|
249 |
for i in range(max_new_tokens):
|
250 |
if i == 0:
|
251 |
out = model(
|
252 |
-
torch.as_tensor([input_ids]).cuda(),
|
253 |
-
|
254 |
-
**image_args)
|
255 |
logits = out.logits
|
256 |
past_key_values = out.past_key_values
|
257 |
else:
|
258 |
attention_mask = torch.ones(
|
259 |
-
1, past_key_values[0][0].shape[-2] + 1, device="cuda"
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
|
|
|
|
|
|
264 |
logits = out.logits
|
265 |
past_key_values = out.past_key_values
|
266 |
|
@@ -342,7 +420,9 @@ async def generate_stream(request: Request):
|
|
342 |
worker.send_heart_beat()
|
343 |
generator = worker.generate_stream_gate(params)
|
344 |
background_tasks = BackgroundTasks()
|
345 |
-
background_tasks.add_task(
|
|
|
|
|
346 |
return StreamingResponse(generator, background=background_tasks)
|
347 |
|
348 |
|
@@ -355,13 +435,17 @@ if __name__ == "__main__":
|
|
355 |
parser = argparse.ArgumentParser()
|
356 |
parser.add_argument("--host", type=str, default="localhost")
|
357 |
parser.add_argument("--port", type=int, default=21002)
|
358 |
-
parser.add_argument("--worker-address", type=str,
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
363 |
parser.add_argument("--model-name", type=str)
|
364 |
-
parser.add_argument(
|
|
|
|
|
|
|
|
|
365 |
parser.add_argument("--keep-aspect-ratio", action="store_true")
|
366 |
parser.add_argument("--num-gpus", type=int, default=1)
|
367 |
parser.add_argument("--limit-model-concurrency", type=int, default=5)
|
@@ -371,14 +455,18 @@ if __name__ == "__main__":
|
|
371 |
logger.info(f"args: {args}")
|
372 |
|
373 |
if args.multi_modal:
|
374 |
-
logger.warning(
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
|
|
|
|
|
|
|
|
384 |
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
|
|
4 |
import argparse
|
5 |
import asyncio
|
6 |
import dataclasses
|
|
|
7 |
import json
|
8 |
+
import logging
|
|
|
9 |
import threading
|
10 |
+
import time
|
11 |
import uuid
|
12 |
+
from functools import partial
|
13 |
+
from typing import List, Union
|
14 |
|
|
|
|
|
15 |
import requests
|
|
|
16 |
import torch
|
17 |
import uvicorn
|
18 |
+
from fastapi import BackgroundTasks, FastAPI, Request
|
19 |
+
from fastapi.responses import StreamingResponse
|
20 |
from llava.constants import WORKER_HEART_BEAT_INTERVAL
|
|
|
|
|
21 |
from llava.model import *
|
22 |
+
from llava.utils import build_logger, pretty_print_semaphore, server_error_msg
|
23 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
24 |
|
25 |
GB = 1 << 30
|
26 |
|
|
|
38 |
|
39 |
|
40 |
def heart_beat_worker(controller):
|
|
|
41 |
while True:
|
42 |
time.sleep(WORKER_HEART_BEAT_INTERVAL)
|
43 |
controller.send_heart_beat()
|
|
|
53 |
}
|
54 |
|
55 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
56 |
+
if "llava" in model_name.lower():
|
57 |
+
if "mpt" in model_name.lower():
|
58 |
+
model = LlavaMPTForCausalLM.from_pretrained(
|
59 |
+
model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs
|
60 |
+
)
|
61 |
else:
|
62 |
+
model = LlavaLlamaForCausalLM.from_pretrained(
|
63 |
+
model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs
|
64 |
+
)
|
65 |
+
elif "mpt" in model_name.lower():
|
66 |
+
model = AutoModelForCausalLM.from_pretrained(
|
67 |
+
model_path,
|
68 |
+
torch_dtype=torch.float16,
|
69 |
+
low_cpu_mem_usage=True,
|
70 |
+
trust_remote_code=True,
|
71 |
+
**kwargs,
|
72 |
+
)
|
73 |
else:
|
74 |
+
model = AutoModelForCausalLM.from_pretrained(
|
75 |
+
model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs
|
76 |
+
)
|
77 |
|
78 |
image_processor = None
|
79 |
|
80 |
+
if "llava" in model_name.lower():
|
81 |
from transformers import CLIPImageProcessor, CLIPVisionModel
|
82 |
+
|
83 |
+
image_processor = CLIPImageProcessor.from_pretrained(
|
84 |
+
model.config.mm_vision_tower, torch_dtype=torch.float16
|
85 |
+
)
|
86 |
|
87 |
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
88 |
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
89 |
if mm_use_im_start_end:
|
90 |
+
tokenizer.add_tokens(
|
91 |
+
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
|
92 |
+
)
|
93 |
|
94 |
vision_tower = model.get_model().vision_tower[0]
|
95 |
+
if vision_tower.device.type == "meta":
|
96 |
+
vision_tower = CLIPVisionModel.from_pretrained(
|
97 |
+
vision_tower.config._name_or_path,
|
98 |
+
torch_dtype=torch.float16,
|
99 |
+
low_cpu_mem_usage=True,
|
100 |
+
).cuda()
|
101 |
model.get_model().vision_tower[0] = vision_tower
|
102 |
else:
|
103 |
+
vision_tower.to(device="cuda", dtype=torch.float16)
|
104 |
vision_config = vision_tower.config
|
105 |
+
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
|
106 |
+
[DEFAULT_IMAGE_PATCH_TOKEN]
|
107 |
+
)[0]
|
108 |
vision_config.use_im_start_end = mm_use_im_start_end
|
109 |
if mm_use_im_start_end:
|
110 |
+
(
|
111 |
+
vision_config.im_start_token,
|
112 |
+
vision_config.im_end_token,
|
113 |
+
) = tokenizer.convert_tokens_to_ids(
|
114 |
+
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
|
115 |
+
)
|
116 |
|
117 |
if num_gpus == 1:
|
118 |
model.cuda()
|
|
|
126 |
|
127 |
|
128 |
class ModelWorker:
|
129 |
+
def __init__(
|
130 |
+
self,
|
131 |
+
controller_addr,
|
132 |
+
worker_addr,
|
133 |
+
worker_id,
|
134 |
+
no_register,
|
135 |
+
model_path,
|
136 |
+
model_name,
|
137 |
+
keep_aspect_ratio,
|
138 |
+
num_gpus,
|
139 |
+
):
|
140 |
self.controller_addr = controller_addr
|
141 |
self.worker_addr = worker_addr
|
142 |
self.worker_id = worker_id
|
|
|
144 |
model_path = model_path[:-1]
|
145 |
if model_name is None:
|
146 |
model_paths = model_path.split("/")
|
147 |
+
if model_paths[-1].startswith("checkpoint-"):
|
148 |
self.model_name = model_paths[-2] + "_" + model_paths[-1]
|
149 |
else:
|
150 |
self.model_name = model_paths[-1]
|
|
|
154 |
logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
|
155 |
self.keep_aspect_ratio = keep_aspect_ratio
|
156 |
self.tokenizer, self.model, self.image_processor, self.context_len = load_model(
|
157 |
+
model_path, self.model_name, num_gpus
|
158 |
+
)
|
159 |
+
self.is_multimodal = "llava" in model_path.lower()
|
160 |
|
161 |
if not no_register:
|
162 |
model/llava/serve/model_worker.py (continued)
The remaining hunks re-wrap long calls to Black style; after the change they read:

        self.register_to_controller()
        self.heart_beat_thread = threading.Thread(
            target=heart_beat_worker, args=(self,)
        )
        self.heart_beat_thread.start()

    def register_to_controller(self):
        ...
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status(),
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        logger.info(
            f"Send heart beat. Models: {[self.model_name]}. "
            f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
            f"global_counter: {global_counter}"
        )

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(
                    url,
                    json={
                        "worker_name": self.worker_addr,
                        "queue_length": self.get_queue_length(),
                    },
                    timeout=5,
                )
                exist = ret.json()["exist"]
                break
            except requests.exceptions.RequestException as e:
        ...
        if model_semaphore is None:
            return 0
        else:
            return (
                args.limit_model_concurrency
                - model_semaphore._value
                + (
                    len(model_semaphore._waiters)
                    if model_semaphore._waiters is not None
                    else 0
                )
            )

    def get_status(self):
        return {
        ...

    @torch.inference_mode()
    def generate_stream(self, params):
        tokenizer, model, image_processor = (
            self.tokenizer,
            self.model,
            self.image_processor,
        )

        prompt = params["prompt"]
        ori_prompt = prompt
        images = params.get("images", None)
        if images is not None and len(images) > 0 and self.is_multimodal:
            import base64
            from io import BytesIO

            from PIL import Image

            assert type(images) is list
            if len(images) > 0:
                # assert len(images) == 1, "Only support one image for now"
                images = [
                    Image.open(BytesIO(base64.b64decode(image))) for image in images
                ]
                assert len(images) == prompt.count(
                    DEFAULT_IMAGE_TOKEN
                ), "Number of images does not match number of <image> tokens in prompt"

                if self.keep_aspect_ratio:
                    new_images = []
                    ...
                        aspect_ratio = max_hw / min_hw
                        max_len, min_len = 448, 224
                        shortest_edge = int(min(max_len / aspect_ratio, min_len))
                        image = image_processor.preprocess(
                            image,
                            return_tensors="pt",
                            do_center_crop=False,
                            size={"shortest_edge": shortest_edge},
                        )["pixel_values"][0]
                        new_images.append(
                            image.to(self.model.device, dtype=torch.float16)
                        )
                        # replace the image token with the image patch token in the prompt (each occurrence)
                        cur_token_len = (image.shape[1] // 14) * (image.shape[2] // 14)
                        replace_token = DEFAULT_IMAGE_PATCH_TOKEN * cur_token_len
                        if getattr(self.model.config, "mm_use_im_start_end", False):
                            replace_token = (
                                DEFAULT_IM_START_TOKEN
                                + replace_token
                                + DEFAULT_IM_END_TOKEN
                            )
                        prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token, 1)
                    images = new_images
                else:
                    images = image_processor(images, return_tensors="pt")[
                        "pixel_values"
                    ]
                    images = images.to(self.model.device, dtype=torch.float16)
                    replace_token = (
                        DEFAULT_IMAGE_PATCH_TOKEN * 256
                    )  # HACK: 256 is the max image token length hacked
                    if getattr(self.model.config, "mm_use_im_start_end", False):
                        replace_token = (
                            DEFAULT_IM_START_TOKEN
                            + replace_token
                            + DEFAULT_IM_END_TOKEN
                        )
                    prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
        else:
            images = None
        ...
        for i in range(max_new_tokens):
            if i == 0:
                out = model(
                    torch.as_tensor([input_ids]).cuda(), use_cache=True, **image_args
                )
                logits = out.logits
                past_key_values = out.past_key_values
            else:
                attention_mask = torch.ones(
                    1, past_key_values[0][0].shape[-2] + 1, device="cuda"
                )
                out = model(
                    input_ids=torch.as_tensor([[token]], device="cuda"),
                    use_cache=True,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                )
                logits = out.logits
                past_key_values = out.past_key_values
        ...
    worker.send_heart_beat()
    generator = worker.generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    background_tasks.add_task(
        partial(release_model_semaphore, fn=worker.send_heart_beat)
    )
    return StreamingResponse(generator, background=background_tasks)
    ...
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-name", type=str)
    parser.add_argument(
        "--multi-modal",
        action="store_true",
        help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.",
    )
    parser.add_argument("--keep-aspect-ratio", action="store_true")
    parser.add_argument("--num-gpus", type=int, default=1)
    parser.add_argument("--limit-model-concurrency", type=int, default=5)
    ...
    logger.info(f"args: {args}")

    if args.multi_modal:
        logger.warning(
            "Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path."
        )

    worker = ModelWorker(
        args.controller_address,
        args.worker_address,
        worker_id,
        args.no_register,
        args.model_path,
        args.model_name,
        args.keep_aspect_ratio,
        args.num_gpus,
    )
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
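The queue-length bookkeeping above leans on private attributes of `asyncio.Semaphore`. A minimal, self-contained sketch of the same arithmetic (the concurrency limit here is a hypothetical constant standing in for `args.limit_model_concurrency`):

```
import asyncio

limit_model_concurrency = 5  # hypothetical stand-in for args.limit_model_concurrency
model_semaphore = asyncio.Semaphore(limit_model_concurrency)


def get_queue_length() -> int:
    # Requests currently holding the semaphore plus requests waiting on it.
    # _value / _waiters are CPython implementation details, used exactly as in
    # the worker code above; the None check covers versions that create the
    # waiter deque lazily.
    if model_semaphore is None:
        return 0
    return (
        limit_model_concurrency
        - model_semaphore._value
        + (len(model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
    )


print(get_queue_length())  # 0 while no request has acquired the semaphore
```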
model/llava/serve/test_message.py
CHANGED
Formatting-only changes (isort import grouping, Black-style call wrapping). The affected hunks after the change:

import json

import requests
from llava.conversation import default_conversation

...
    models.sort()
    print(f"Models: {models}")

    ret = requests.post(
        controller_addr + "/get_worker_address", json={"model": args.model_name}
    )
    worker_addr = ret.json()["address"]
    print(f"worker_addr: {worker_addr}")
...
        "temperature": 0.7,
        "stop": conv.sep,
    }
    response = requests.post(
        worker_addr + "/worker_generate_stream",
        headers=headers,
        json=pload,
        stream=True,
    )

    print(prompt.replace(conv.sep, "\n"), end="")
    for chunk in response.iter_lines(
        chunk_size=8192, decode_unicode=False, delimiter=b"\0"
    ):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"].split(conv.sep)[-1]
...

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    parser.add_argument("--worker-address", type=str)
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    parser.add_argument("--max-new-tokens", type=int, default=32)
    parser.add_argument(
        "--message", type=str, default="Tell me a story with more than 1000 words."
    )
    args = parser.parse_args()

    main()
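The client above splits the streamed response on NUL bytes and JSON-decodes each piece. A self-contained sketch of that framing, with hard-coded payloads standing in for the worker's output:

```
import json

# Two JSON payloads terminated by b"\0", mimicking the worker's streaming framing.
stream = b'{"text": "Hello"}\0{"text": "Hello world"}\0'

for chunk in stream.split(b"\0"):
    if chunk:
        data = json.loads(chunk.decode("utf-8"))
        print(data["text"])
```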
model/llava/train/llama_flash_attn_monkey_patch.py
CHANGED
Formatting-only changes (isort import grouping, Black-style wrapping). After the change the file reads:

from typing import List, Optional, Tuple

import torch
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from torch import nn
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb


def forward(
    self,
    ...
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel

    attention_mask: [bsz, q_len]
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    # [bsz, q_len, nh, hd]
    # [bsz, nh, q_len, hd]
...
        offset = past_key_value[0].shape[-2]
        kv_seq_len += offset
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, offset=offset
    )
    # [bsz, nh, t, hd]
    assert not output_attentions, "output_attentions is not supported"
    assert not use_cache, "use_cache is not supported"
...
    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py

    # transform the data into the format required by flash attention
    qkv = torch.stack(
        [query_states, key_states, value_states], dim=2
    )  # [bsz, nh, 3, q_len, hd]
    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
    # We have disabled _prepare_decoder_attention_mask in LlamaModel
    # the attention_mask should be the same as the key_padding_mask
    key_padding_mask = attention_mask

    if key_padding_mask is None:
        qkv = rearrange(qkv, "b s ... -> (b s) ...")
        max_s = q_len
        cu_q_lens = torch.arange(
            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
        )
        output = flash_attn_unpadded_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
    else:
        nheads = qkv.shape[-2]
        x = rearrange(qkv, "b s three h d -> b s (three h d)")
        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
        x_unpad = rearrange(
            x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
        )
        output_unpad = flash_attn_unpadded_qkvpacked_func(
            x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output = rearrange(
            pad_input(
                rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
            ),
            "b s (h d) -> b s h d",
            h=nheads,
        )
    return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None


# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # [bsz, seq_len]
    return attention_mask


def replace_llama_attn_with_flash_attn():
    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask
    )
    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
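A minimal usage sketch of this patch, mirroring how train_mem.py applies it before the model is built so that `LlamaAttention.forward` is already replaced. It assumes flash-attn is installed, a CUDA GPU is available, and the checkpoint path is hypothetical:

```
from llava.train.llama_flash_attn_monkey_patch import \
    replace_llama_attn_with_flash_attn

replace_llama_attn_with_flash_attn()

import transformers  # imported after patching, as in train_mem.py

model = transformers.AutoModelForCausalLM.from_pretrained(
    "/path/to/llama-checkpoint",  # hypothetical local checkpoint
    torch_dtype="auto",
)
# FlashAttention expects fp16/bf16 activations on GPU at runtime.
```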
model/llava/train/llava_trainer.py
CHANGED
Formatting-only changes (isort import grouping, Black-style wrapping). The affected hunks after the change:

import os
from typing import Dict, Optional, Sequence

import torch
import torch.nn as nn
from transformers import Trainer


def unwrap_model(model: nn.Module) -> nn.Module:
...

class LLaVATrainer(Trainer):
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        if getattr(self.args, "tune_mm_mlp_adapter", False):
            # Save the model
            _state_dict = state_dict
            if _state_dict is None:
...
                _state_dict = model_to_save.state_dict()

            weight_to_save = {}
            keys_to_match = ["mm_projector", "embed_tokens", "embed_in"]
            for k, v in _state_dict.items():
                if any(key_match in k for key_match in keys_to_match):
                    weight_to_save[k] = v

            current_folder = output_dir.split("/")[-1]
            parent_folder = os.path.dirname(output_dir)
            if current_folder.startswith("checkpoint-"):
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(
                    weight_to_save,
                    os.path.join(mm_projector_folder, f"{current_folder}.bin"),
                )
            else:
                torch.save(
                    weight_to_save, os.path.join(output_dir, f"mm_projector.bin")
                )

        super(LLaVATrainer, self)._save(output_dir, state_dict)
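The checkpoint written by `_save` contains only the projector and embedding tensors matched by `keys_to_match`. A quick way to inspect such a file (the path is hypothetical):

```
import torch

# Hypothetical path; the trainer writes either <output_dir>/mm_projector.bin or
# <parent>/mm_projector/checkpoint-XXXX.bin depending on the folder name.
weights = torch.load("output/mm_projector.bin", map_location="cpu")
for name, tensor in weights.items():
    print(name, tuple(tensor.shape))
# The dict can be merged back into a model via model.load_state_dict(weights, strict=False).
```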
model/llava/train/train.py
CHANGED
The changes re-wrap code to Black style and regroup imports with isort; no functional changes are visible in the hunks. After the change the affected hunks read:

# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import logging
import os
import pathlib
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence

import torch
import torch.nn as nn
import transformers
from llava import conversation as conversation_lib
from llava.model import *
from llava.train.llava_trainer import LLaVATrainer
from PIL import Image
from torch.utils.data import Dataset

# TODO: import and use code from ../data/dataset.py
...
    freeze_backbone: bool = field(default=False)
    tune_mm_mlp_adapter: bool = field(default=False)
    vision_tower: Optional[str] = field(default=None)
    mm_vision_select_layer: Optional[int] = field(
        default=-1
    )  # default to the last layer
    pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
    mm_use_im_start_end: bool = field(default=False)


@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    sep_image_conv_front: bool = False
    image_token_len: int = 0
    image_folder: Optional[str] = field(default=None)
    image_aspect_ratio: str = "square"


@dataclass
...
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
...
    output_embeddings = model.get_output_embeddings().weight.data

    input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
        dim=0, keepdim=True
    )
    output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
        dim=0, keepdim=True
    )

    input_embeddings[-num_new_tokens:] = input_embeddings_avg
    output_embeddings[-num_new_tokens:] = output_embeddings_avg


def _tokenize_fn(
    strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer
) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            ...
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
...
    target[:cur_idx] = IGNORE_INDEX
    for tokenized_len, speaker in zip(tokenized_lens, speakers):
        if speaker == "human":
            target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
        cur_idx += tokenized_len
...
        elif from_str.lower() == "gpt":
            from_str = conversation_lib.default_conversation.roles[1]
        else:
            from_str = "unknown"
        sentence["value"] = (
            BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
        )
        if get_conversation:
            conversation += sentence["value"]
    conversation += BEGIN_SIGNAL
...
    multimodal_cfg: dict,
    cur_token_len: int,
) -> Dict:
    is_multimodal = multimodal_cfg["is_multimodal"]
    # image_token_len = multimodal_cfg['image_token_len']
    image_token_len = cur_token_len
    if not is_multimodal:
        return sources

    for source in sources:
        if multimodal_cfg["sep_image_conv_front"]:
            assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
            source[0]["value"] = (
                source[0]["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
            )
            source[0]["value"] = (
                DEFAULT_IMAGE_TOKEN
                + conversation_lib.default_conversation.sep
                + conversation_lib.default_conversation.roles[0]
                + ": "
                + source[0]["value"]
            )
        for sentence in source:
            replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
            if multimodal_cfg["use_im_start_end"]:
                replace_token = (
                    DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
                )
            sentence["value"] = sentence["value"].replace(
                DEFAULT_IMAGE_TOKEN, replace_token
            )

    return sources
...
        labels=targets,
    )


def preprocess_mpt(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
...
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep)
        re_rounds = [conv.sep.join(rounds[:3])]  # system + user + gpt
        for conv_idx in range(3, len(rounds), 2):
            re_rounds.append(
                conv.sep.join(rounds[conv_idx : conv_idx + 2])
            )  # user + gpt
        cur_len = 0
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(re_rounds):
...
            if len(parts) != 2:
                break
            parts[0] += sep
            round_len = len(tokenizer(rou).input_ids) + len(
                tokenizer(conv.sep).input_ids
            )
            instruction_len = len(tokenizer(parts[0]).input_ids)
            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
...
    input_ids = conversations_tokenized["input_ids"]
    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        tokenized_lens = _tokenize_fn(
            [header] + [s["value"] for s in source], tokenizer
        )["input_ids_lens"]
        speakers = [sentence["from"] for sentence in source]
        _mask_targets(target, tokenized_lens, speakers)
...
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        list_data_dict = json.load(open(data_path, "r"))
...
class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path: str,
        tokenizer: transformers.PreTrainedTokenizer,
        multimodal_cfg: dict,
    ):
        super(LazySupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        list_data_dict = json.load(open(data_path, "r"))
...
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
        if "image" in sources[0]:
            image_file = self.list_data_dict[i]["image"]
            image_folder = self.multimodal_cfg["image_folder"]
            processor = self.multimodal_cfg["image_processor"]
            image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
            if self.multimodal_cfg["image_aspect_ratio"] == "keep":
                max_hw, min_hw = max(image.size), min(image.size)
                aspect_ratio = max_hw / min_hw
                max_len, min_len = 448, 224
                shortest_edge = int(min(max_len / aspect_ratio, min_len))
                image = processor.preprocess(
                    image,
                    return_tensors="pt",
                    do_center_crop=False,
                    size={"shortest_edge": shortest_edge},
                )["pixel_values"][0]
            elif self.multimodal_cfg["image_aspect_ratio"] == "pad":

                def expand2square(pil_img, background_color):
                    width, height = pil_img.size
                    if width == height:
                        return pil_img
                    elif width > height:
                        result = Image.new(
                            pil_img.mode, (width, width), background_color
                        )
                        result.paste(pil_img, (0, (width - height) // 2))
                        return result
                    else:
                        result = Image.new(
                            pil_img.mode, (height, height), background_color
                        )
                        result.paste(pil_img, ((height - width) // 2, 0))
                        return result

                image = expand2square(
                    image, tuple(int(x * 255) for x in processor.image_mean)
                )
                image = processor.preprocess(image, return_tensors="pt")[
                    "pixel_values"
                ][0]
            else:
                image = processor.preprocess(image, return_tensors="pt")[
                    "pixel_values"
                ][0]
            cur_token_len = (image.shape[1] // 14) * (
                image.shape[2] // 14
            )  # FIXME: 14 is hardcoded patch size
            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.multimodal_cfg,
                cur_token_len,
            )
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
        data_dict = preprocess(sources, self.tokenizer)
        if isinstance(i, int):
            data_dict = dict(
                input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]
            )

        # image exist in the data
        if "image" in self.list_data_dict[i]:
            data_dict["image"] = image
        elif self.multimodal_cfg["is_multimodal"]:
            # image does not exist in the data, but the model is multimodal
            crop_size = self.multimodal_cfg["image_processor"].crop_size
            data_dict["image"] = torch.zeros(3, crop_size["height"], crop_size["width"])
        return data_dict
...
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple(
            [instance[key] for instance in instances] for key in ("input_ids", "labels")
        )
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX
        )
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        if "image" in instances[0]:
            images = [instance["image"] for instance in instances]
            if all(x is not None and x.shape == images[0].shape for x in images):
                batch["images"] = torch.stack(images)
            else:
                batch["images"] = images

        return batch


def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer, data_args
) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    dataset_cls = (
        LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
    )
    train_dataset = dataset_cls(
        tokenizer=tokenizer,
        data_path=data_args.data_path,
        multimodal_cfg=dict(
            is_multimodal=data_args.is_multimodal,
            sep_image_conv_front=data_args.sep_image_conv_front,
            image_token_len=data_args.image_token_len,
            image_folder=data_args.image_folder,
            image_aspect_ratio=data_args.image_aspect_ratio,
            use_im_start_end=getattr(data_args, "mm_use_im_start_end", False),
            image_processor=getattr(data_args, "image_processor", None),
        ),
    )
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(
        train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
    )


def train():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if model_args.vision_tower is not None:
        if "mpt" in model_args.model_name_or_path:
            model = LlavaMPTForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
...
    if model_args.freeze_backbone:
        model.model.requires_grad_(False)

    if "mpt" in model_args.model_name_or_path:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right",
        )
    else:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
...
            model=model,
        )
    if "llama" in model_args.model_name_or_path:
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )
    else:
        tokenizer.pad_token = tokenizer.unk_token
    if "mpt" in model_args.model_name_or_path:
        conversation_lib.default_conversation = conversation_lib.conv_templates[
            "mpt"
        ]
    else:
        conversation_lib.default_conversation = conversation_lib.conv_templates[
            "vicuna_v1_1"
        ]

    if model_args.vision_tower is not None:
        model_vision_dict = model.get_model().initialize_vision_modules(
            vision_tower=model_args.vision_tower,
            mm_vision_select_layer=model_args.mm_vision_select_layer,
            pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter,
        )
        dtype = torch.float32
        if training_args.fp16:
...
        if training_args.bf16:
            dtype = torch.bfloat16
        model.get_model().vision_tower[0].to(dtype=dtype, device=training_args.device)
        vision_config = model_vision_dict["vision_config"]

        data_args.image_token_len = model_vision_dict["image_token_len"]
        data_args.image_processor = model_vision_dict["image_processor"]
        data_args.is_multimodal = True

        model.config.tune_mm_mlp_adapter = (
            training_args.tune_mm_mlp_adapter
        ) = model_args.tune_mm_mlp_adapter
        if model_args.tune_mm_mlp_adapter:
            model.requires_grad_(False)
            for p in model.get_model().mm_projector.parameters():
...
            for p in model.get_model().mm_projector.parameters():
                p.requires_grad = False

        model.config.mm_use_im_start_end = (
            data_args.mm_use_im_start_end
        ) = model_args.mm_use_im_start_end
        vision_config.use_im_start_end = (
            training_args.use_im_start_end
        ) = model_args.mm_use_im_start_end
        model.config.sep_image_conv_front = data_args.sep_image_conv_front
        model.initialize_vision_tokenizer(
            mm_use_im_start_end=model_args.mm_use_im_start_end,
            tokenizer=tokenizer,
            device=training_args.device,
            tune_mm_mlp_adapter=model_args.tune_mm_mlp_adapter,
            pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter,
        )

    params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
    if len(params_no_grad) > 0:
        if training_args.fsdp is not None and len(training_args.fsdp) > 0:
            if len(params_no_grad) < 10:
                print(
                    "[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}".format(
                        len(params_no_grad), params_no_grad
                    )
                )
            else:
                print(
                    "[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)".format(
                        len(params_no_grad), ", ".join(params_no_grad[:10])
                    )
                )
            print(
                "[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental."
            )
            print(
                "[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining"
            )

            from torch.distributed.fsdp.fully_sharded_data_parallel import \
                FullyShardedDataParallel as FSDP

            def patch_FSDP_use_orig_params(func):
                def wrap_func(*args, **kwargs):
                    use_orig_params = kwargs.pop("use_orig_params", True)
                    return func(*args, **kwargs, use_orig_params=use_orig_params)

                return wrap_func

            FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = LLaVATrainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )

    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)


if __name__ == "__main__":
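The entry point parses the three dataclasses with `HfArgumentParser`, so every field above becomes a `--flag`. A hedged sketch of populating them programmatically (the checkpoint, data path, and vision-tower values are hypothetical placeholders):

```
import transformers

from llava.train.train import DataArguments, ModelArguments, TrainingArguments

parser = transformers.HfArgumentParser(
    (ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses(
    args=[
        "--model_name_or_path", "/path/to/llama-checkpoint",  # hypothetical
        "--vision_tower", "openai/clip-vit-large-patch14",    # hypothetical choice
        "--data_path", "/path/to/conversations.json",         # hypothetical
        "--lazy_preprocess", "True",
        "--output_dir", "./runs/llava",
    ]
)
print(data_args.lazy_preprocess, training_args.output_dir)
```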
model/llava/train/train_mem.py
CHANGED
Formatting-only change: the monkey-patch import is wrapped with a line continuation. After the change:

# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.

# Need to call this before importing transformers.
from llava.train.llama_flash_attn_monkey_patch import \
    replace_llama_attn_with_flash_attn

replace_llama_attn_with_flash_attn()
model/llava/utils.py
CHANGED
Formatting-only changes (long string constants wrapped, Black-style call wrapping). The affected hunks after the change:

import sys

import requests
from llava.constants import LOGDIR

server_error_msg = (
    "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
)
moderation_msg = (
    "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
)

handler = None
...
    os.makedirs(LOGDIR, exist_ok=True)
    filename = os.path.join(LOGDIR, logger_filename)
    handler = logging.handlers.TimedRotatingFileHandler(
        filename, when="D", utc=True
    )
    handler.setFormatter(formatter)

    for name, item in logging.root.manager.loggerDict.items():
...

class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.
    """

    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ""

    def __getattr__(self, attr):
        return getattr(self.terminal, attr)

    def write(self, buf):
        temp_linebuf = self.linebuf + buf
        self.linebuf = ""
        for line in temp_linebuf.splitlines(True):
            # From the io.TextIOWrapper docs:
            # On output, if newline is None, any '\n' characters written
            # are translated to the system default line separator.
            # By default sys.stdout.write() expects '\n' newlines and then
            # translates them so this is still cross platform.
            if line[-1] == "\n":
                self.logger.log(self.log_level, line.rstrip())
            else:
                self.linebuf += line

    def flush(self):
        if self.linebuf != "":
            self.logger.log(self.log_level, self.linebuf.rstrip())
        self.linebuf = ""


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch

    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.
    """
    url = "https://api.openai.com/v1/moderations"
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"],
    }
    text = text.replace("\n", "")
    data = "{" + '"input": ' + f'"{text}"' + "}"
    data = data.encode("utf-8")
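A small usage sketch of the helpers above: build the rotating file logger and route `print` output through it. The logger and file names are arbitrary; the log file lands under `LOGDIR` from llava.constants:

```
import logging
import sys

from llava.utils import StreamToLogger, build_logger

logger = build_logger("demo", "demo.log")          # arbitrary names
sys.stdout = StreamToLogger(logger, logging.INFO)  # mirror stdout into the log
print("this line is captured in LOGDIR/demo.log")
```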
model/segment_anything/__init__.py
CHANGED
Formatting-only change: the re-exports are regrouped in isort order. After the change:

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .automatic_mask_generator import SamAutomaticMaskGenerator
from .build_sam import (build_sam, build_sam_vit_b, build_sam_vit_h,
                        build_sam_vit_l, sam_model_registry)
from .predictor import SamPredictor
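A hedged usage sketch of the re-exported API. The checkpoint path is hypothetical, the registry keys and the `generate` method follow the upstream segment-anything interface, and the import path assumes the repository root is on PYTHONPATH:

```
import numpy as np

from model.segment_anything import SamAutomaticMaskGenerator, sam_model_registry

sam = sam_model_registry["vit_h"](checkpoint="/path/to/sam_vit_h.pth")  # hypothetical
mask_generator = SamAutomaticMaskGenerator(sam)

image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder RGB image
masks = mask_generator.generate(image)           # list of per-mask dicts
```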
model/segment_anything/automatic_mask_generator.py
CHANGED
Formatting-only changes (isort import grouping, Black-style wrapping). The affected hunks after the change:

from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from torchvision.ops.boxes import batched_nms, box_area  # type: ignore

from .modeling import Sam
from .predictor import SamPredictor
from .utils.amg import (MaskData, area_from_rle, batch_iterator,
                        batched_mask_to_box, box_xyxy_to_xywh,
                        build_all_layer_point_grids, calculate_stability_score,
                        coco_encode_rle, generate_crop_boxes,
                        is_box_near_crop_edge, mask_to_rle_pytorch,
                        remove_small_regions, rle_to_mask, uncrop_boxes_xyxy,
                        uncrop_masks, uncrop_points)


class SamAutomaticMaskGenerator:
...
            "coco_rle",
        ], f"Unknown output_mode {output_mode}."
        if output_mode == "coco_rle":
            from pycocotools import \
                mask as mask_utils  # type: ignore # noqa: F401

        if min_mask_region_area > 0:
            import cv2  # type: ignore # noqa: F401
...
        # Encode masks
        if self.output_mode == "coco_rle":
            mask_data["segmentations"] = [
                coco_encode_rle(rle) for rle in mask_data["rles"]
            ]
        elif self.output_mode == "binary_mask":
            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
        else:
...
        # Generate masks for this crop in batches
        data = MaskData()
        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
            batch_data = self._process_batch(
                points, cropped_im_size, crop_box, orig_size
            )
            data.cat(batch_data)
            del batch_data
        self.predictor.reset_image()
...
        # Run model on this batch
        transformed_points = self.predictor.transform.apply_coords(points, im_size)
        in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
        in_labels = torch.ones(
            in_points.shape[0], dtype=torch.int, device=in_points.device
        )
        masks, iou_preds, _ = self.predictor.predict_torch(
            in_points[:, None, :],
            in_labels[:, None],
...
        # Calculate stability score
        data["stability_score"] = calculate_stability_score(
            data["masks"],
            self.predictor.model.mask_threshold,
|
299 |
+
self.stability_score_offset,
|
300 |
)
|
301 |
if self.stability_score_thresh > 0.0:
|
302 |
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
|
|
307 |
data["boxes"] = batched_mask_to_box(data["masks"])
|
308 |
|
309 |
# Filter boxes that touch crop boundaries
|
310 |
+
keep_mask = ~is_box_near_crop_edge(
|
311 |
+
data["boxes"], crop_box, [0, 0, orig_w, orig_h]
|
312 |
+
)
|
313 |
if not torch.all(keep_mask):
|
314 |
data.filter(keep_mask)
|
315 |
|
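As with the other files in this commit, every hunk in `automatic_mask_generator.py` is a formatting change (import grouping and line wrapping); the generator's behaviour is untouched. A minimal usage sketch for orientation — the checkpoint path and image file below are placeholders:

```
# Illustrative usage; not part of this commit.
import cv2

from model.segment_anything import SamAutomaticMaskGenerator, sam_model_registry

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")  # placeholder path
mask_generator = SamAutomaticMaskGenerator(sam, output_mode="binary_mask")

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)  # placeholder image
masks = mask_generator.generate(image)  # list of dicts: "segmentation", "bbox", "area", ...
```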