x-lai committed on
Commit 3d9fba4 · 1 Parent(s): 11d7ed8

Release training script

Former-commit-id: 6f951959fdf50617a5ad55be75bb9139e63fa04b

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitignore +4 -1
  2. README.md +79 -2
  3. chat.py +214 -144
  4. model/LISA.py +471 -194
  5. model/llava/conversation.py +50 -28
  6. model/llava/eval/eval_gpt_review.py +57 -43
  7. model/llava/eval/eval_gpt_review_visual.py +67 -48
  8. model/llava/eval/eval_science_qa.py +44 -40
  9. model/llava/eval/eval_science_qa_gpt4.py +36 -32
  10. model/llava/eval/eval_science_qa_gpt4_requery.py +85 -70
  11. model/llava/eval/generate_webpage_data_from_table.py +46 -38
  12. model/llava/eval/model_qa.py +30 -17
  13. model/llava/eval/model_vqa.py +113 -50
  14. model/llava/eval/model_vqa_science.py +136 -69
  15. model/llava/eval/qa_baseline_gpt35.py +35 -27
  16. model/llava/eval/run_llava.py +78 -33
  17. model/llava/eval/run_llava_batch.py +240 -69
  18. model/llava/eval/run_llava_batch_v2.py +240 -69
  19. model/llava/eval/run_llava_batch_v3.py +244 -70
  20. model/llava/eval/summarize_gpt_review.py +12 -10
  21. model/llava/model/__init__.py +2 -2
  22. model/llava/model/apply_delta.py +16 -8
  23. model/llava/model/consolidate.py +4 -2
  24. model/llava/model/llava.py +211 -65
  25. model/llava/model/llava_mpt.py +284 -76
  26. model/llava/model/make_delta.py +19 -8
  27. model/llava/model/mpt/adapt_tokenizer.py +11 -6
  28. model/llava/model/mpt/attention.py +289 -70
  29. model/llava/model/mpt/blocks.py +60 -11
  30. model/llava/model/mpt/configuration_mpt.py +104 -28
  31. model/llava/model/mpt/hf_prefixlm_converter.py +439 -104
  32. model/llava/model/mpt/meta_init_context.py +27 -10
  33. model/llava/model/mpt/modeling_mpt.py +278 -84
  34. model/llava/model/mpt/norm.py +67 -17
  35. model/llava/model/mpt/param_init_fns.py +290 -52
  36. model/llava/model/utils.py +21 -10
  37. model/llava/serve/cli.py +25 -20
  38. model/llava/serve/controller.py +35 -25
  39. model/llava/serve/gradio_css.py +3 -5
  40. model/llava/serve/gradio_patch.py +6 -7
  41. model/llava/serve/gradio_web_server.py +208 -92
  42. model/llava/serve/model_worker.py +170 -82
  43. model/llava/serve/test_message.py +18 -9
  44. model/llava/train/llama_flash_attn_monkey_patch.py +53 -41
  45. model/llava/train/llava_trainer.py +13 -9
  46. model/llava/train/train.py +211 -138
  47. model/llava/train/train_mem.py +2 -1
  48. model/llava/utils.py +19 -11
  49. model/segment_anything/__init__.py +3 -8
  50. model/segment_anything/automatic_mask_generator.py +26 -26
.gitignore CHANGED
@@ -1 +1,4 @@
- **/__pycache__
+ **/__pycache__
+ runs/
+ .vscode/
+
README.md CHANGED
@@ -47,6 +47,83 @@ For more details, please refer to the [paper](https://arxiv.org/abs/2308.00692).
  ```
  pip install -r requirements.txt
  ```
+
+ ## Training
+ ### Training Data Preparation
+ The training data consists of 4 types of data:
+
+ 1. Semantic segmentation datasets: [ADE20K](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip), [COCO-Stuff](https://github.com/nightrome/cocostuff#downloads), [Mapillary](https://www.mapillary.com/dataset/vistas), [PACO-LVIS](https://github.com/facebookresearch/paco/tree/main#dataset-setup), [PASCAL-Part](http://roozbehm.info/pascal-parts/pascal-parts.html)
+
+ 2. Referring segmentation datasets: refCOCO, refCOCO+, refCOCOg [\[Download\]](https://github.com/lichengunc/refer#download)
+
+ 3. Visual Question Answering dataset: [LLaVA-Instruct-150k](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_instruct_150k.json)
+
+ 4. Reasoning segmentation dataset: [ReasonSeg](https://github.com/dvlab-research/LISA#dataset)
+
+ Download them from the links above and organize them as follows.
+
+ ```
+ ├── dataset
+ │   ├── ade20k
+ │   │   ├── annotations
+ │   │   └── images
+ │   ├── coco
+ │   │   └── train2017
+ │   ├── cocostuff
+ │   │   ├── annotations
+ │   │   └── train2017
+ │   ├── llava_dataset
+ │   │   └── llava_instruct_150k.json
+ │   ├── mapillary
+ │   │   ├── config_v2.0.json
+ │   │   ├── testing
+ │   │   ├── training
+ │   │   └── validation
+ │   ├── reason_seg
+ │   │   └── ReasonSeg
+ │   │       ├── train
+ │   │       ├── val
+ │   │       └── explanatory
+ │   ├── refer_seg
+ │   │   ├── images
+ │   │   │   ├── saiapr_tc-12
+ │   │   │   └── mscoco
+ │   │   │       └── images
+ │   │   │           └── train2014
+ │   │   ├── refclef
+ │   │   ├── refcoco
+ │   │   ├── refcoco+
+ │   │   └── refcocog
+ │   └── vlpart
+ │       ├── paco
+ │       │   └── annotations
+ │       └── pascal_part
+ │           ├── train.json
+ │           └── VOCdevkit
+ ```
+
+ ### Pre-trained weights
+
+ #### LLaVA
+ To train LISA-7B or 13B, you need to follow the [instructions](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md) to merge the LLaVA delta weights. Typically, we use the final weights `LLaVA-Lightning-7B-v1-1` and `LLaVA-13B-v1-1` merged from `liuhaotian/LLaVA-Lightning-7B-delta-v1-1` and `liuhaotian/LLaVA-13b-delta-v1-1`, respectively. For Llama2, we can directly use the LLaVA full weights `liuhaotian/llava-llama-2-13b-chat-lightning-preview`.
+
+ #### SAM ViT-H weights
+ Download the SAM ViT-H pre-trained weights from this [link](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth).
+
+ ### Training
+ ```
+ deepspeed --master_port=24999 train_ds.py --version="PATH_TO_LLaVA_Weights" --dataset_dir='./dataset' --vision_pretrained="PATH_TO_SAM_Weights" --exp_name="lisa-7b"
+ ```
+ When training is finished, get the full model weight with:
+ ```
+ cd ./runs/lisa-7b/ckpt_model && python zero_to_fp32.py . ../pytorch_model.bin
+ ```
+
+ ### Validation
+ ```
+ deepspeed --master_port=24999 train_ds.py --version="PATH_TO_LLaVA_Weights" --dataset_dir='./dataset' --vision_pretrained="PATH_TO_SAM_Weights" --exp_name="lisa-7b" --weight='PATH_TO_pytorch_model.bin' --eval_only
+ ```
+
 
  ## Inference
  To chat with [LISA-13B-llama2-v0](https://huggingface.co/xinlai/LISA-13B-llama2-v0) or [LISA-13B-llama2-v0-explainatory](https://huggingface.co/xinlai/LISA-13B-llama2-v0-explainatory): (Note that LISA-13B-llama2-v0 currently does not support explanatory answers.)
@@ -93,9 +170,9 @@ Important keys contained in JSON files:
 
  The elements of the "shapes" exhibit two categories, namely **"target"** and **"ignore"**. The former is indispensable for evaluation, while the latter denotes ambiguous regions and is therefore disregarded during evaluation.
 
- We provide a <a href="https://github.com/dvlab-research/LISA/blob/main/utils/data_proc_demo.py">**script**</a> that demonstrates how to process the annotations:
+ We provide a <a href="https://github.com/dvlab-research/LISA/blob/main/utils/data_processing.py">**script**</a> that demonstrates how to process the annotations:
  ```
- python3 utils/data_proc_demo.py
+ python3 utils/data_processing.py
  ```
 
  Besides, we leveraged GPT-3.5 to rephrase instructions, so images in the training set may have **more than one instruction (but fewer than six)** in the "text" field. During training, users may randomly select one as the text query to obtain a better model.
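The annotation handling described in the README hunk above (multiple rephrased instructions in the "text" field, "target" vs. "ignore" shapes) can be sketched as follows. This is not the repository's `utils/data_processing.py`; the exact key that marks a shape as "target" or "ignore" (assumed here to be `label`) and the sample path are illustrative assumptions.

```python
import json
import random


def load_reason_seg_sample(json_path):
    """Minimal sketch: pick one text query and keep only evaluable shapes.

    Assumes each shape dict carries a "label" field containing "target" or
    "ignore", as described in the README; adjust to the real schema used by
    utils/data_processing.py if it differs.
    """
    with open(json_path, "r") as f:
        ann = json.load(f)

    # Training images may carry up to five rephrased instructions;
    # randomly sample one as the text query.
    instruction = random.choice(ann["text"])

    # Keep "target" polygons for supervision; drop ambiguous "ignore" regions.
    target_shapes = [s for s in ann["shapes"] if "ignore" not in s["label"]]
    return instruction, target_shapes


# Hypothetical usage (path shown for illustration only):
# instruction, shapes = load_reason_seg_sample(
#     "dataset/reason_seg/ReasonSeg/train/example.json"
# )
```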
chat.py CHANGED
@@ -1,38 +1,48 @@
1
- import sys
2
  import os
 
 
3
  import cv2
4
- import argparse
5
- import torch
6
- import transformers
7
  import numpy as np
 
8
  import torch.nn.functional as F
9
-
10
  from transformers import AutoTokenizer, CLIPImageProcessor
11
 
12
  from model.LISA import LISA
13
- from utils.conversation import get_default_conv_template
14
  from model.segment_anything.utils.transforms import ResizeLongestSide
 
 
15
 
16
  def parse_args(args):
17
- parser = argparse.ArgumentParser(description='LISA chat')
18
- parser.add_argument('--version', default='xinlai/LISA-13B-llama2-v0')
19
- parser.add_argument('--vis_save_path', default='./vis_output', type=str)
20
- parser.add_argument('--precision', default='bf16', type=str, choices=['fp32', 'bf16', 'fp16'], help="precision for inference")
21
- parser.add_argument('--image-size', default=1024, type=int, help='image size')
22
- parser.add_argument('--model-max-length', default=512, type=int)
23
- parser.add_argument('--lora-r', default=-1, type=int)
24
- parser.add_argument('--vision-tower', default='openai/clip-vit-large-patch14', type=str)
25
- parser.add_argument('--local-rank', default=0, type=int, help='node rank')
26
- parser.add_argument('--load_in_8bit', action='store_true', default=False)
27
- parser.add_argument('--load_in_4bit', action='store_true', default=False)
28
- return parser.parse_args(args)
29
-
30
-
31
- def preprocess(x,
32
- pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
33
  pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
34
- img_size=1024
35
- ) -> torch.Tensor:
36
  """Normalize pixel values and pad to a square input."""
37
  # Normalize colors
38
  x = (x - pixel_mean) / pixel_std
@@ -45,125 +55,185 @@ def preprocess(x,
45
 
46
 
47
  def main(args):
48
- args = parse_args(args)
49
- os.makedirs(args.vis_save_path, exist_ok=True)
50
-
51
- # Create model
52
- tokenizer = transformers.AutoTokenizer.from_pretrained(
53
- args.version,
54
- cache_dir=None,
55
- model_max_length=args.model_max_length,
56
- padding_side="right",
57
- use_fast=False,
58
- )
59
- tokenizer.pad_token = tokenizer.unk_token
60
- num_added_tokens = tokenizer.add_tokens('[SEG]')
61
- ret_token_idx = tokenizer('[SEG]', add_special_tokens=False).input_ids
62
- args.seg_token_idx = ret_token_idx[0]
63
-
64
- model = LISA(
65
- args.local_rank,
66
- args.seg_token_idx,
67
- tokenizer,
68
- args.version,
69
- args.lora_r,
70
- args.precision,
71
- load_in_8bit=args.load_in_8bit,
72
- load_in_4bit=args.load_in_4bit,
73
- )
74
-
75
- weight = {}
76
- visual_model_weight = torch.load(os.path.join(args.version, "pytorch_model-visual_model.bin"))
77
- text_hidden_fcs_weight = torch.load(os.path.join(args.version, "pytorch_model-text_hidden_fcs.bin"))
78
- weight.update(visual_model_weight)
79
- weight.update(text_hidden_fcs_weight)
80
- missing_keys, unexpected_keys = model.load_state_dict(weight, strict=False)
81
-
82
- if args.precision == 'bf16':
83
- model = model.bfloat16().cuda()
84
- elif args.precision == 'fp16':
85
- import deepspeed
86
- model_engine = deepspeed.init_inference(model=model,
87
- dtype=torch.half,
88
- replace_with_kernel_inject=True,
89
- replace_method="auto",
90
  )
91
- model = model_engine.module
92
- else:
93
- model = model.float().cuda()
94
-
95
- DEFAULT_IMAGE_TOKEN = "<image>"
96
- DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
97
- DEFAULT_IM_START_TOKEN = "<im_start>"
98
- DEFAULT_IM_END_TOKEN = "<im_end>"
99
- image_token_len = 256
100
-
101
- clip_image_processor = CLIPImageProcessor.from_pretrained(args.vision_tower)
102
- transform = ResizeLongestSide(args.image_size)
103
-
104
- while True:
105
-
106
- conv = get_default_conv_template("vicuna").copy()
107
- conv.messages = []
108
-
109
- prompt = input("Please input your prompt: ")
110
- prompt = DEFAULT_IMAGE_TOKEN + " " + prompt
111
- replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
112
- replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
113
- prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
114
-
115
- conv.append_message(conv.roles[0], prompt)
116
- conv.append_message(conv.roles[1], "")
117
- prompt = conv.get_prompt()
118
-
119
- image_path = input("Please input the image path: ")
120
- if not os.path.exists(image_path):
121
- print("File not found in {}".format(image_path))
122
- continue
123
-
124
- image = cv2.imread(image_path)
125
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
126
- original_size_list = [image.shape[:2]]
127
- if args.precision == 'bf16':
128
- images_clip = clip_image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0].unsqueeze(0).cuda().bfloat16()
129
- elif args.precision == 'fp16':
130
- images_clip = clip_image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0].unsqueeze(0).cuda().half()
131
- else:
132
- images_clip = clip_image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0].unsqueeze(0).cuda().float()
133
- images = transform.apply_image(image)
134
- resize_list = [images.shape[:2]]
135
- if args.precision == 'bf16':
136
- images = preprocess(torch.from_numpy(images).permute(2,0,1).contiguous()).unsqueeze(0).cuda().bfloat16()
137
- elif args.precision == 'fp16':
138
- images = preprocess(torch.from_numpy(images).permute(2,0,1).contiguous()).unsqueeze(0).cuda().half()
139
  else:
140
- images = preprocess(torch.from_numpy(images).permute(2,0,1).contiguous()).unsqueeze(0).cuda().float()
141
-
142
- input_ids = tokenizer(prompt).input_ids
143
- input_ids = torch.LongTensor(input_ids).unsqueeze(0).cuda()
144
- output_ids, pred_masks = model.evaluate(images_clip, images, input_ids, resize_list, original_size_list, max_new_tokens=512, tokenizer=tokenizer)
145
- text_output = tokenizer.decode(output_ids[0], skip_special_tokens=False)
146
- text_output = text_output.replace(DEFAULT_IMAGE_PATCH_TOKEN, "").replace("\n", "").replace(" ", "")
147
-
148
- print("text_output: ", text_output)
149
- for i, pred_mask in enumerate(pred_masks):
150
-
151
- if pred_mask.shape[0] == 0:
152
- continue
153
-
154
- pred_mask = pred_mask.detach().cpu().numpy()[0]
155
- pred_mask = (pred_mask > 0)
156
-
157
- save_path = "{}/{}_mask_{}.jpg".format(args.vis_save_path, image_path.split("/")[-1].split(".")[0], i)
158
- cv2.imwrite(save_path, pred_mask * 100)
159
- print("{} has been saved.".format(save_path))
160
-
161
- save_path = "{}/{}_masked_img_{}.jpg".format(args.vis_save_path, image_path.split("/")[-1].split(".")[0], i)
162
- save_img = image.copy()
163
- save_img[pred_mask] = (image * 0.5 + pred_mask[:,:,None].astype(np.uint8) * np.array([255,0,0]) * 0.5)[pred_mask]
164
- save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
165
- cv2.imwrite(save_path, save_img)
166
- print("{} has been saved.".format(save_path))
167
-
168
- if __name__ == '__main__':
169
- main(sys.argv[1:])
1
+ import argparse
2
  import os
3
+ import sys
4
+
5
  import cv2
 
 
 
6
  import numpy as np
7
+ import torch
8
  import torch.nn.functional as F
9
+ import transformers
10
  from transformers import AutoTokenizer, CLIPImageProcessor
11
 
12
  from model.LISA import LISA
 
13
  from model.segment_anything.utils.transforms import ResizeLongestSide
14
+ from utils.conversation import get_default_conv_template
15
+
16
 
17
  def parse_args(args):
18
+ parser = argparse.ArgumentParser(description="LISA chat")
19
+ parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v0")
20
+ parser.add_argument("--vis_save_path", default="./vis_output", type=str)
21
+ parser.add_argument(
22
+ "--precision",
23
+ default="bf16",
24
+ type=str,
25
+ choices=["fp32", "bf16", "fp16"],
26
+ help="precision for inference",
27
+ )
28
+ parser.add_argument("--image-size", default=1024, type=int, help="image size")
29
+ parser.add_argument("--model-max-length", default=512, type=int)
30
+ parser.add_argument("--lora-r", default=-1, type=int)
31
+ parser.add_argument(
32
+ "--vision-tower", default="openai/clip-vit-large-patch14", type=str
33
+ )
34
+ parser.add_argument("--local-rank", default=0, type=int, help="node rank")
35
+ parser.add_argument("--load_in_8bit", action="store_true", default=False)
36
+ parser.add_argument("--load_in_4bit", action="store_true", default=False)
37
+ return parser.parse_args(args)
38
+
39
+
40
+ def preprocess(
41
+ x,
42
+ pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
43
  pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
44
+ img_size=1024,
45
+ ) -> torch.Tensor:
46
  """Normalize pixel values and pad to a square input."""
47
  # Normalize colors
48
  x = (x - pixel_mean) / pixel_std
 
55
 
56
 
57
  def main(args):
58
+ args = parse_args(args)
59
+ os.makedirs(args.vis_save_path, exist_ok=True)
60
+
61
+ # Create model
62
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
63
+ args.version,
64
+ cache_dir=None,
65
+ model_max_length=args.model_max_length,
66
+ padding_side="right",
67
+ use_fast=False,
68
  )
69
+ tokenizer.pad_token = tokenizer.unk_token
70
+ num_added_tokens = tokenizer.add_tokens("[SEG]")
71
+ ret_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids
72
+ args.seg_token_idx = ret_token_idx[0]
73
+
74
+ model = LISA(
75
+ args.local_rank,
76
+ args.seg_token_idx,
77
+ tokenizer,
78
+ args.version,
79
+ args.lora_r,
80
+ args.precision,
81
+ load_in_8bit=args.load_in_8bit,
82
+ load_in_4bit=args.load_in_4bit,
83
+ )
84
+
85
+ weight = {}
86
+ visual_model_weight = torch.load(
87
+ os.path.join(args.version, "pytorch_model-visual_model.bin")
88
+ )
89
+ text_hidden_fcs_weight = torch.load(
90
+ os.path.join(args.version, "pytorch_model-text_hidden_fcs.bin")
91
+ )
92
+ weight.update(visual_model_weight)
93
+ weight.update(text_hidden_fcs_weight)
94
+ missing_keys, unexpected_keys = model.load_state_dict(weight, strict=False)
95
+
96
+ if args.precision == "bf16":
97
+ model = model.bfloat16().cuda()
98
+ elif args.precision == "fp16":
99
+ import deepspeed
100
+
101
+ model_engine = deepspeed.init_inference(
102
+ model=model,
103
+ dtype=torch.half,
104
+ replace_with_kernel_inject=True,
105
+ replace_method="auto",
106
+ )
107
+ model = model_engine.module
108
  else:
109
+ model = model.float().cuda()
110
+
111
+ DEFAULT_IMAGE_TOKEN = "<image>"
112
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
113
+ DEFAULT_IM_START_TOKEN = "<im_start>"
114
+ DEFAULT_IM_END_TOKEN = "<im_end>"
115
+ image_token_len = 256
116
+
117
+ clip_image_processor = CLIPImageProcessor.from_pretrained(args.vision_tower)
118
+ transform = ResizeLongestSide(args.image_size)
119
+
120
+ while True:
121
+ conv = get_default_conv_template("vicuna").copy()
122
+ conv.messages = []
123
+
124
+ prompt = input("Please input your prompt: ")
125
+ prompt = DEFAULT_IMAGE_TOKEN + " " + prompt
126
+ replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
127
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
128
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
129
+
130
+ conv.append_message(conv.roles[0], prompt)
131
+ conv.append_message(conv.roles[1], "")
132
+ prompt = conv.get_prompt()
133
+
134
+ image_path = input("Please input the image path: ")
135
+ if not os.path.exists(image_path):
136
+ print("File not found in {}".format(image_path))
137
+ continue
138
+
139
+ image = cv2.imread(image_path)
140
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
141
+ original_size_list = [image.shape[:2]]
142
+ if args.precision == "bf16":
143
+ images_clip = (
144
+ clip_image_processor.preprocess(image, return_tensors="pt")[
145
+ "pixel_values"
146
+ ][0]
147
+ .unsqueeze(0)
148
+ .cuda()
149
+ .bfloat16()
150
+ )
151
+ elif args.precision == "fp16":
152
+ images_clip = (
153
+ clip_image_processor.preprocess(image, return_tensors="pt")[
154
+ "pixel_values"
155
+ ][0]
156
+ .unsqueeze(0)
157
+ .cuda()
158
+ .half()
159
+ )
160
+ else:
161
+ images_clip = (
162
+ clip_image_processor.preprocess(image, return_tensors="pt")[
163
+ "pixel_values"
164
+ ][0]
165
+ .unsqueeze(0)
166
+ .cuda()
167
+ .float()
168
+ )
169
+ images = transform.apply_image(image)
170
+ resize_list = [images.shape[:2]]
171
+ if args.precision == "bf16":
172
+ images = (
173
+ preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
174
+ .unsqueeze(0)
175
+ .cuda()
176
+ .bfloat16()
177
+ )
178
+ elif args.precision == "fp16":
179
+ images = (
180
+ preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
181
+ .unsqueeze(0)
182
+ .cuda()
183
+ .half()
184
+ )
185
+ else:
186
+ images = (
187
+ preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
188
+ .unsqueeze(0)
189
+ .cuda()
190
+ .float()
191
+ )
192
+
193
+ input_ids = tokenizer(prompt).input_ids
194
+ input_ids = torch.LongTensor(input_ids).unsqueeze(0).cuda()
195
+ output_ids, pred_masks = model.evaluate(
196
+ images_clip,
197
+ images,
198
+ input_ids,
199
+ resize_list,
200
+ original_size_list,
201
+ max_new_tokens=512,
202
+ tokenizer=tokenizer,
203
+ )
204
+ text_output = tokenizer.decode(output_ids[0], skip_special_tokens=False)
205
+ text_output = (
206
+ text_output.replace(DEFAULT_IMAGE_PATCH_TOKEN, "")
207
+ .replace("\n", "")
208
+ .replace(" ", "")
209
+ )
210
+
211
+ print("text_output: ", text_output)
212
+ for i, pred_mask in enumerate(pred_masks):
213
+ if pred_mask.shape[0] == 0:
214
+ continue
215
+
216
+ pred_mask = pred_mask.detach().cpu().numpy()[0]
217
+ pred_mask = pred_mask > 0
218
+
219
+ save_path = "{}/{}_mask_{}.jpg".format(
220
+ args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
221
+ )
222
+ cv2.imwrite(save_path, pred_mask * 100)
223
+ print("{} has been saved.".format(save_path))
224
+
225
+ save_path = "{}/{}_masked_img_{}.jpg".format(
226
+ args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
227
+ )
228
+ save_img = image.copy()
229
+ save_img[pred_mask] = (
230
+ image * 0.5
231
+ + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
232
+ )[pred_mask]
233
+ save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
234
+ cv2.imwrite(save_path, save_img)
235
+ print("{} has been saved.".format(save_path))
236
+
237
+
238
+ if __name__ == "__main__":
239
+ main(sys.argv[1:])
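Since the refactored `chat.py` keeps the `main(sys.argv[1:])` entry point and a plain `argparse` parser, the interactive demo can also be driven programmatically. A minimal sketch, assuming it is run from the repository root with the project's dependencies installed and that the default `bf16` precision suits your GPU:

```python
# Minimal sketch: launch the chat demo with explicit arguments instead of
# the command line. The flags mirror chat.py's argument parser; on smaller
# GPUs, "--precision", "fp16" together with "--load_in_8bit" or
# "--load_in_4bit" can be added instead.
import chat

chat.main([
    "--version", "xinlai/LISA-13B-llama2-v0",
    "--precision", "bf16",
    "--vis_save_path", "./vis_output",
])
```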
model/LISA.py CHANGED
@@ -1,213 +1,490 @@
1
- from typing import Callable, List, Optional, Tuple, Union
2
- import json
3
- import glob
4
- import math
5
- import numpy as np
6
- import os
7
  import torch
8
  import torch.nn as nn
9
  import torch.nn.functional as F
10
- import transformers
11
-
12
- from transformers import LlamaForCausalLM, CLIPVisionModel, BitsAndBytesConfig
13
- from peft import (
14
- LoraConfig,
15
- get_peft_model,
16
- get_peft_model_state_dict,
17
- prepare_model_for_int8_training,
18
- set_peft_model_state_dict,
19
- )
20
- from .llava.model.llava import LlavaLlamaForCausalLM
21
- from .segment_anything import build_sam_vit_l, build_sam_vit_h
22
 
23
- DEFAULT_IMAGE_TOKEN = "<image>"
24
- DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
25
- DEFAULT_IM_START_TOKEN = "<im_start>"
26
- DEFAULT_IM_END_TOKEN = "<im_end>"
27
 
28
- def find_all_linear_names(model):
29
- cls = torch.nn.Linear
30
- lora_module_names = set()
31
- for name, module in model.named_modules():
32
- if isinstance(module, cls):
33
- names = name.split('.')
34
- lora_module_names.add(names[0] if len(names) == 1 else names[-1])
35
 
36
- if 'lm_head' in lora_module_names: # needed for 16-bit
37
- lora_module_names.remove('lm_head')
38
 
39
- if 'mm_projector' in lora_module_names:
40
- lora_module_names.remove('mm_projector')
41
 
42
- return sorted(list(lora_module_names))
43
 
44
  class LISA(nn.Module):
45
- def __init__(self,
46
- local_rank,
47
- seg_token_idx,
48
- tokenizer,
49
- llm_version,
50
- lora_r,
51
- precision,
52
- load_in_4bit=False,
53
- load_in_8bit=False,
54
- lora_target_modules=['q_proj', 'v_proj'],
55
- lora_alpha=16,
56
- lora_dropout=0.05,
57
- vision_tower='openai/clip-vit-large-patch14',
58
- mm_vision_select_layer=-2,
59
- freeze_lm=True,
60
- train_mask_decoder=True,
61
- out_dim=256,
62
- ):
63
-
64
- super().__init__()
65
- self.tokenizer = tokenizer
66
- self.image_token = tokenizer.cls_token_id
67
- self.precision = precision
68
-
69
- # LLaVA
70
- tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
71
- num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
72
- if precision == "bf16":
73
- self.lm = LlavaLlamaForCausalLM.from_pretrained(llm_version, torch_dtype=torch.bfloat16, cache_dir=None, low_cpu_mem_usage=True)
74
- elif precision == "fp16":
75
- if load_in_4bit:
76
- self.lm = LlavaLlamaForCausalLM.from_pretrained(llm_version, load_in_4bit=True, cache_dir=None, low_cpu_mem_usage=True, device_map='auto',
77
- quantization_config=BitsAndBytesConfig(
78
- load_in_4bit=True,
79
- bnb_4bit_compute_dtype=torch.float16,
80
- bnb_4bit_use_double_quant=True,
81
- bnb_4bit_quant_type='nf4'
82
- )
83
- )
84
- elif load_in_8bit:
85
- self.lm = LlavaLlamaForCausalLM.from_pretrained(llm_version, load_in_8bit=True, cache_dir=None, low_cpu_mem_usage=True, device_map='auto')
86
- else:
87
- self.lm = LlavaLlamaForCausalLM.from_pretrained(llm_version, torch_dtype=torch.half, cache_dir=None, low_cpu_mem_usage=True)
88
- else:
89
- self.lm = LlavaLlamaForCausalLM.from_pretrained(llm_version, torch_dtype=torch.float32, cache_dir=None, low_cpu_mem_usage=True)
90
-
91
- self.lm.enable_input_require_grads()
92
- self.lm.gradient_checkpointing_enable()
93
- self.lm.config.use_cache = False
94
- model_vision_dict = self.lm.get_model().initialize_vision_modules(vision_tower=vision_tower, mm_vision_select_layer=mm_vision_select_layer, precision=precision)
95
- vision_config = model_vision_dict['vision_config']
96
- vision_tower = self.lm.get_model().vision_tower[0]
97
- self.lm.model.config.eos_token_id = tokenizer.eos_token_id
98
- self.lm.model.config.bos_token_id = tokenizer.bos_token_id
99
- self.lm.model.config.pad_token_id = tokenizer.pad_token_id
100
-
101
- if vision_tower.device.type == 'meta':
102
- if precision == 'bf16':
103
- vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).cuda(local_rank)
104
- elif precision == 'fp16':
105
- vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.half, low_cpu_mem_usage=True).cuda(local_rank)
106
- else:
107
- vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float32, low_cpu_mem_usage=True).cuda(local_rank)
108
- self.lm.get_model().vision_tower[0] = vision_tower
109
- else:
110
 
 
 
 
 
 
111
  if precision == "bf16":
112
- vision_tower.to(device='cuda', dtype=torch.bfloat16)
 
 
 
 
 
113
  elif precision == "fp16":
114
- vision_tower.to(device='cuda', dtype=torch.half)
115
  else:
116
- vision_tower.to(device='cuda', dtype=torch.float32)
117
-
118
- self.lm.config.tune_mm_mlp_adapter = False
119
- self.lm.config.freeze_mm_mlp_adapter = False
120
- self.lm.config.mm_use_im_start_end = True
121
- vision_config.use_im_start_end = True
122
- self.lm.config.sep_image_conv_front = False
123
-
124
- self.lm.initialize_vision_tokenizer(mm_use_im_start_end=True, tokenizer=tokenizer, num_new_tokens=num_new_tokens, device=local_rank, tune_mm_mlp_adapter=False)
125
- if freeze_lm:
126
- for n, param in self.lm.named_parameters():
127
- param.requires_grad = False
128
-
129
- self.llm_version = llm_version
130
-
131
- self.seg_token_idx = seg_token_idx
132
- self.lm.resize_token_embeddings(len(tokenizer))
133
-
134
- for n, p in self.lm.named_parameters():
135
- if any([x in n for x in ['lm_head', 'embed_tokens']]) and p.shape[0] == len(tokenizer):
136
- p.requires_grad = True
137
-
138
- # SAM
139
- self.visual_model = build_sam_vit_h(None)
140
- for param in self.visual_model.parameters():
141
- param.requires_grad = False
142
- if train_mask_decoder:
143
- self.visual_model.mask_decoder.train()
144
- for param in self.visual_model.mask_decoder.parameters():
145
- param.requires_grad = True
146
-
147
- # Projection layer
148
- in_dim = self.lm.config.hidden_size
149
- text_fc = [nn.Linear(in_dim, in_dim), nn.ReLU(inplace=True), nn.Linear(in_dim, out_dim), nn.Dropout(0.0)]
150
- self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])
151
-
152
- def get_visual_embs(self, pixel_values: torch.FloatTensor):
153
- image_embeddings = self.visual_model.image_encoder(pixel_values)
154
- return image_embeddings
155
-
156
- def evaluate(self, images_clip, images, input_ids, resize_list, original_size_list, max_new_tokens=32, tokenizer=None):
157
-
158
- with torch.no_grad():
159
- outputs = self.lm.generate(images=images_clip, input_ids=input_ids, max_new_tokens=max_new_tokens, num_beams=1, output_hidden_states=True, return_dict_in_generate=True)
160
- output_hidden_states = outputs.hidden_states[-1]
161
- output_ids = outputs.sequences
162
-
163
- seg_token_mask = (output_ids[:, 1:] == self.seg_token_idx)
164
-
165
- last_embedding = None
166
- last_output_logit = None
167
- hidden_states = []
168
-
169
- assert len(self.text_hidden_fcs) == 1
170
- hidden_states.append(self.text_hidden_fcs[0](output_hidden_states))
171
-
172
- last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
173
- pred_embeddings = last_hidden_state[seg_token_mask]
174
-
175
- seg_token_counts = seg_token_mask.int().sum(-1) #[bs, ]
176
- seg_token_offset = seg_token_counts.cumsum(-1)
177
- seg_token_offset = torch.cat([torch.zeros(1).long().cuda(), seg_token_offset], dim=0)
178
-
179
- pred_embeddings_ = []
180
- for i in range(len(seg_token_offset)-1):
181
- start_i, end_i = seg_token_offset[i], seg_token_offset[i+1]
182
- pred_embeddings_.append(pred_embeddings[start_i: end_i])
183
- pred_embeddings = pred_embeddings_
184
-
185
- image_embeddings = self.get_visual_embs(images)
186
-
187
- multimask_output = False
188
- pred_masks = []
189
- for i in range(len(pred_embeddings)):
190
- sparse_embeddings, dense_embeddings = self.visual_model.prompt_encoder(
191
- points=None,
192
- boxes=None,
193
- masks=None,
194
- text_embeds=pred_embeddings[i].unsqueeze(1),
195
  )
 
 
 
 
 
196
 
197
- sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
198
- low_res_masks, iou_predictions = self.visual_model.mask_decoder(
199
- image_embeddings=image_embeddings[i].unsqueeze(0),
200
- image_pe=self.visual_model.prompt_encoder.get_dense_pe(),
201
- sparse_prompt_embeddings=sparse_embeddings,
202
- dense_prompt_embeddings=dense_embeddings,
203
- multimask_output=multimask_output,
204
  )
205
 
206
- pred_mask = self.visual_model.postprocess_masks(
207
- low_res_masks,
208
- input_size=resize_list[i],
209
- original_size=original_size_list[i],
210
  )
211
- pred_masks.append(pred_mask[:, 0])
212
-
213
- return output_ids, pred_masks
1
+ from typing import List
2
+
 
 
 
 
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
+ from peft import (LoraConfig, get_peft_model)
7
+ from transformers import BitsAndBytesConfig, CLIPVisionModel
8
 
9
+ from .llava.model.llava import LlavaLlamaForCausalLM
10
+ from .segment_anything import build_sam_vit_h
11
+ from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
12
+ DEFAULT_IMAGE_PATCH_TOKEN)
13
 
14
+ def dice_loss(
15
+ inputs: torch.Tensor,
16
+ targets: torch.Tensor,
17
+ num_masks: float,
18
+ scale=1000, # 100000.0,
19
+ eps=1e-6,
20
+ ):
21
+ """
22
+ Compute the DICE loss, similar to generalized IOU for masks
23
+ Args:
24
+ inputs: A float tensor of arbitrary shape.
25
+ The predictions for each example.
26
+ targets: A float tensor with the same shape as inputs. Stores the binary
27
+ classification label for each element in inputs
28
+ (0 for the negative class and 1 for the positive class).
29
+ """
30
+ inputs = inputs.sigmoid()
31
+ inputs = inputs.flatten(1, 2)
32
+ targets = targets.flatten(1, 2)
33
+ numerator = 2 * (inputs / scale * targets).sum(-1)
34
+ denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1)
35
+ loss = 1 - (numerator + eps) / (denominator + eps)
36
+ loss = loss.sum() / (num_masks + 1e-8)
37
+ return loss
38
 
 
 
39
 
40
+ def sigmoid_ce_loss(
41
+ inputs: torch.Tensor,
42
+ targets: torch.Tensor,
43
+ num_masks: float,
44
+ ):
45
+ """
46
+ Args:
47
+ inputs: A float tensor of arbitrary shape.
48
+ The predictions for each example.
49
+ targets: A float tensor with the same shape as inputs. Stores the binary
50
+ classification label for each element in inputs
51
+ (0 for the negative class and 1 for the positive class).
52
+ Returns:
53
+ Loss tensor
54
+ """
55
+ loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
56
+ loss = loss.flatten(1, 2).mean(1).sum() / (num_masks + 1e-8)
57
+ return loss
58
 
 
59
 
60
  class LISA(nn.Module):
61
+ def __init__(
62
+ self,
63
+ local_rank,
64
+ seg_token_idx,
65
+ tokenizer,
66
+ llm_version,
67
+ lora_r,
68
+ precision,
69
+ load_in_4bit=False,
70
+ load_in_8bit=False,
71
+ lora_target_modules=["q_proj", "v_proj"],
72
+ lora_alpha=16,
73
+ lora_dropout=0.05,
74
+ vision_tower="openai/clip-vit-large-patch14",
75
+ mm_vision_select_layer=-2,
76
+ freeze_lm=True,
77
+ train_mask_decoder=True,
78
+ out_dim=256,
79
+ ce_loss_weight=1.0,
80
+ dice_loss_weight=0.5,
81
+ bce_loss_weight=2.0,
82
+ vision_pretrained=None,
83
+ ):
84
+ super().__init__()
85
+ self.local_rank = local_rank
86
+ self.tokenizer = tokenizer
87
+ self.image_token = tokenizer.cls_token_id
88
+ self.precision = precision
89
+ self.ce_loss_weight = ce_loss_weight
90
+ self.dice_loss_weight = dice_loss_weight
91
+ self.bce_loss_weight = bce_loss_weight
92
 
93
+ # LLaVA
94
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
95
+ num_new_tokens = tokenizer.add_tokens(
96
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
97
+ )
98
  if precision == "bf16":
99
+ self.lm = LlavaLlamaForCausalLM.from_pretrained(
100
+ llm_version,
101
+ torch_dtype=torch.bfloat16,
102
+ cache_dir=None,
103
+ low_cpu_mem_usage=True,
104
+ )
105
  elif precision == "fp16":
106
+ if load_in_4bit:
107
+ self.lm = LlavaLlamaForCausalLM.from_pretrained(
108
+ llm_version,
109
+ load_in_4bit=True,
110
+ cache_dir=None,
111
+ low_cpu_mem_usage=True,
112
+ device_map="auto",
113
+ quantization_config=BitsAndBytesConfig(
114
+ load_in_4bit=True,
115
+ bnb_4bit_compute_dtype=torch.float16,
116
+ bnb_4bit_use_double_quant=True,
117
+ bnb_4bit_quant_type="nf4",
118
+ ),
119
+ )
120
+ elif load_in_8bit:
121
+ self.lm = LlavaLlamaForCausalLM.from_pretrained(
122
+ llm_version,
123
+ load_in_8bit=True,
124
+ cache_dir=None,
125
+ low_cpu_mem_usage=True,
126
+ device_map="auto",
127
+ )
128
+ else:
129
+ self.lm = LlavaLlamaForCausalLM.from_pretrained(
130
+ llm_version,
131
+ torch_dtype=torch.half,
132
+ cache_dir=None,
133
+ low_cpu_mem_usage=True,
134
+ )
135
  else:
136
+ self.lm = LlavaLlamaForCausalLM.from_pretrained(
137
+ llm_version,
138
+ torch_dtype=torch.float32,
139
+ cache_dir=None,
140
+ low_cpu_mem_usage=True,
141
+ )
142
+
143
+ self.lm.enable_input_require_grads()
144
+ self.lm.gradient_checkpointing_enable()
145
+ self.lm.config.use_cache = False
146
+ model_vision_dict = self.lm.get_model().initialize_vision_modules(
147
+ vision_tower=vision_tower,
148
+ mm_vision_select_layer=mm_vision_select_layer,
149
+ precision=precision,
150
  )
151
+ vision_config = model_vision_dict["vision_config"]
152
+ vision_tower = self.lm.get_model().vision_tower[0]
153
+ self.lm.model.config.eos_token_id = tokenizer.eos_token_id
154
+ self.lm.model.config.bos_token_id = tokenizer.bos_token_id
155
+ self.lm.model.config.pad_token_id = tokenizer.pad_token_id
156
 
157
+ if vision_tower.device.type == "meta":
158
+ if precision == "bf16":
159
+ vision_tower = CLIPVisionModel.from_pretrained(
160
+ vision_tower.config._name_or_path,
161
+ torch_dtype=torch.bfloat16,
162
+ low_cpu_mem_usage=True,
163
+ ).cuda(local_rank)
164
+ elif precision == "fp16":
165
+ vision_tower = CLIPVisionModel.from_pretrained(
166
+ vision_tower.config._name_or_path,
167
+ torch_dtype=torch.half,
168
+ low_cpu_mem_usage=True,
169
+ ).cuda(local_rank)
170
+ else:
171
+ vision_tower = CLIPVisionModel.from_pretrained(
172
+ vision_tower.config._name_or_path,
173
+ torch_dtype=torch.float32,
174
+ low_cpu_mem_usage=True,
175
+ ).cuda(local_rank)
176
+ self.lm.get_model().vision_tower[0] = vision_tower
177
+ else:
178
+ if precision == "bf16":
179
+ vision_tower.to(device="cuda", dtype=torch.bfloat16)
180
+ elif precision == "fp16":
181
+ vision_tower.to(device="cuda", dtype=torch.half)
182
+ else:
183
+ vision_tower.to(device="cuda", dtype=torch.float32)
184
+
185
+ self.lm.config.tune_mm_mlp_adapter = False
186
+ self.lm.config.freeze_mm_mlp_adapter = False
187
+ self.lm.config.mm_use_im_start_end = True
188
+ vision_config.use_im_start_end = True
189
+ self.lm.config.sep_image_conv_front = False
190
+
191
+ self.lm.initialize_vision_tokenizer(
192
+ mm_use_im_start_end=True,
193
+ tokenizer=tokenizer,
194
+ num_new_tokens=num_new_tokens,
195
+ device=local_rank,
196
+ tune_mm_mlp_adapter=False,
197
  )
198
+ if freeze_lm:
199
+ for n, param in self.lm.named_parameters():
200
+ param.requires_grad = False
201
+
202
+ # LoRA
203
+ if lora_r > 0:
204
+ config = LoraConfig(
205
+ r=lora_r,
206
+ lora_alpha=lora_alpha,
207
+ target_modules=lora_target_modules,
208
+ lora_dropout=lora_dropout,
209
+ bias="none",
210
+ task_type="CAUSAL_LM",
211
+ )
212
+ self.lm = get_peft_model(self.lm, config)
213
+ self.lm.print_trainable_parameters()
214
+
215
+ self.llm_version = llm_version
216
+
217
+ self.seg_token_idx = seg_token_idx
218
+ self.lm.resize_token_embeddings(len(tokenizer))
219
+
220
+ for n, p in self.lm.named_parameters():
221
+ if any([x in n for x in ["lm_head", "embed_tokens"]]) and p.shape[0] == len(tokenizer):
222
+ p.requires_grad = True
223
 
224
+ # SAM
225
+ self.visual_model = build_sam_vit_h(vision_pretrained)
226
+ for param in self.visual_model.parameters():
227
+ param.requires_grad = False
228
+ if train_mask_decoder:
229
+ self.visual_model.mask_decoder.train()
230
+ for param in self.visual_model.mask_decoder.parameters():
231
+ param.requires_grad = True
232
+
233
+ # Projection layer
234
+ in_dim = self.lm.config.hidden_size
235
+ text_fc = [
236
+ nn.Linear(in_dim, in_dim),
237
+ nn.ReLU(inplace=True),
238
+ nn.Linear(in_dim, out_dim),
239
+ nn.Dropout(0.0),
240
+ ]
241
+ self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])
242
+
243
+ def get_visual_embs(self, pixel_values: torch.FloatTensor):
244
+ with torch.no_grad():
245
+ image_embeddings = self.visual_model.image_encoder(pixel_values)
246
+ return image_embeddings
247
+
248
+ def forward(
249
+ self,
250
+ images: torch.FloatTensor,
251
+ images_clip: torch.FloatTensor,
252
+ input_ids: torch.LongTensor,
253
+ labels: torch.LongTensor,
254
+ attention_masks: torch.LongTensor,
255
+ offset: torch.LongTensor,
256
+ masks_list: List[torch.FloatTensor],
257
+ label_list: List[torch.Tensor],
258
+ resize_list: List[tuple],
259
+ inference: bool = False,
260
+ **kwargs,
261
+ ):
262
+ image_embeddings = self.get_visual_embs(images)
263
+ batch_size = image_embeddings.shape[0]
264
+ assert batch_size == len(offset) - 1
265
+
266
+ seg_token_mask = input_ids[:, 1:] == self.seg_token_idx
267
+ seg_token_mask = torch.cat(
268
+ [
269
+ seg_token_mask,
270
+ torch.zeros((seg_token_mask.shape[0], 1)).bool().cuda(self.local_rank),
271
+ ],
272
+ dim=1,
273
  )
274
+
275
+ if inference:
276
+ n_batch = 1
277
+ length = input_ids.shape[0]
278
+ assert images_clip.shape[0] == 1
279
+ images_clip_extend = images_clip.expand(length, -1, -1, -1).contiguous()
280
+
281
+ output_hidden_states = []
282
+ for i in range(n_batch):
283
+ start_i, end_i = i * length, min((i + 1) * length, input_ids.shape[0])
284
+ output_i = self.lm(
285
+ images=images_clip_extend[: end_i - start_i],
286
+ attention_mask=attention_masks[start_i:end_i],
287
+ input_ids=input_ids[start_i:end_i],
288
+ output_hidden_states=True,
289
+ )
290
+ output_hidden_states.append(output_i.hidden_states)
291
+ torch.cuda.empty_cache()
292
+
293
+ output_hidden_states_list = []
294
+ output_hidden_states_level = torch.cat(output_hidden_states, dim=0)
295
+ output_hidden_states_list.append(output_hidden_states_level)
296
+ output_hidden_states = output_hidden_states_list
297
+ output = None
298
+
299
+ else:
300
+ images_clip_list = []
301
+ for i in range(len(offset) - 1):
302
+ start_i, end_i = offset[i], offset[i + 1]
303
+ images_clip_i = (
304
+ images_clip[i]
305
+ .unsqueeze(0)
306
+ .expand(end_i - start_i, -1, -1, -1)
307
+ .contiguous()
308
+ )
309
+ images_clip_list.append(images_clip_i)
310
+ images_clip = torch.cat(images_clip_list, dim=0)
311
+
312
+ output = self.lm(
313
+ images=images_clip,
314
+ attention_mask=attention_masks,
315
+ input_ids=input_ids,
316
+ labels=labels,
317
+ output_hidden_states=True,
318
+ )
319
+ output_hidden_states = output.hidden_states
320
+
321
+ hidden_states = []
322
+
323
+ assert len(self.text_hidden_fcs) == 1
324
+ hidden_states.append(self.text_hidden_fcs[0](output_hidden_states[-1]))
325
+
326
+ last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
327
+
328
+ pred_embeddings = last_hidden_state[seg_token_mask]
329
+ seg_token_counts = seg_token_mask.int().sum(-1) # [bs, ]
330
+
331
+ seg_token_offset = seg_token_counts.cumsum(-1)
332
+ seg_token_offset = torch.cat(
333
+ [torch.zeros(1).long().cuda(), seg_token_offset], dim=0
334
+ )
335
+
336
+ seg_token_offset = seg_token_offset[offset]
337
+
338
+ pred_embeddings_ = []
339
+ for i in range(len(seg_token_offset) - 1):
340
+ start_i, end_i = seg_token_offset[i], seg_token_offset[i + 1]
341
+ pred_embeddings_.append(pred_embeddings[start_i:end_i])
342
+ pred_embeddings = pred_embeddings_
343
+
344
+ multimask_output = False
345
+ pred_masks = []
346
+ for i in range(len(pred_embeddings)):
347
+ sparse_embeddings, dense_embeddings = self.visual_model.prompt_encoder(
348
+ points=None,
349
+ boxes=None,
350
+ masks=None,
351
+ text_embeds=pred_embeddings[i].unsqueeze(1),
352
+ )
353
+ sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
354
+ low_res_masks, iou_predictions = self.visual_model.mask_decoder(
355
+ image_embeddings=image_embeddings[i].unsqueeze(0),
356
+ image_pe=self.visual_model.prompt_encoder.get_dense_pe(),
357
+ sparse_prompt_embeddings=sparse_embeddings,
358
+ dense_prompt_embeddings=dense_embeddings,
359
+ multimask_output=multimask_output,
360
+ )
361
+ pred_mask = self.visual_model.postprocess_masks(
362
+ low_res_masks,
363
+ input_size=resize_list[i],
364
+ original_size=label_list[i].shape,
365
+ )
366
+ pred_masks.append(pred_mask[:, 0])
367
+
368
+ model_output = output
369
+ gt_masks = masks_list
370
+
371
+ if inference:
372
+ return {
373
+ "pred_masks": pred_masks,
374
+ "gt_masks": gt_masks,
375
+ }
376
+
377
+ output = model_output.logits
378
+
379
+ ce_loss = model_output.loss
380
+ ce_loss = ce_loss * self.ce_loss_weight
381
+ loss = ce_loss
382
+ mask_bce_loss = 0
383
+ mask_dice_loss = 0
384
+ num_masks = 0
385
+ for batch_idx in range(len(pred_masks)):
386
+ gt_mask = gt_masks[batch_idx]
387
+ pred_mask = pred_masks[batch_idx]
388
+
389
+ assert (
390
+ gt_mask.shape[0] == pred_mask.shape[0]
391
+ ), "gt_mask.shape: {}, pred_mask.shape: {}".format(
392
+ gt_mask.shape, pred_mask.shape
393
+ )
394
+ mask_bce_loss += (
395
+ sigmoid_ce_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
396
+ * gt_mask.shape[0]
397
+ )
398
+ mask_dice_loss += (
399
+ dice_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
400
+ * gt_mask.shape[0]
401
+ )
402
+ num_masks += gt_mask.shape[0]
403
+
404
+ mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
405
+ mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
406
+ mask_loss = mask_bce_loss + mask_dice_loss
407
+
408
+ loss += mask_loss
409
+
410
+ return {
411
+ "loss": loss,
412
+ "ce_loss": ce_loss,
413
+ "mask_bce_loss": mask_bce_loss,
414
+ "mask_dice_loss": mask_dice_loss,
415
+ "mask_loss": mask_loss,
416
+ }
417
+
418
+ def evaluate(
419
+ self,
420
+ images_clip,
421
+ images,
422
+ input_ids,
423
+ resize_list,
424
+ original_size_list,
425
+ max_new_tokens=32,
426
+ tokenizer=None,
427
+ ):
428
+ with torch.no_grad():
429
+ outputs = self.lm.generate(
430
+ images=images_clip,
431
+ input_ids=input_ids,
432
+ max_new_tokens=max_new_tokens,
433
+ num_beams=1,
434
+ output_hidden_states=True,
435
+ return_dict_in_generate=True,
436
+ )
437
+ output_hidden_states = outputs.hidden_states[-1]
438
+ output_ids = outputs.sequences
439
+
440
+ seg_token_mask = output_ids[:, 1:] == self.seg_token_idx
441
+
442
+ hidden_states = []
443
+
444
+ assert len(self.text_hidden_fcs) == 1
445
+ hidden_states.append(self.text_hidden_fcs[0](output_hidden_states))
446
+
447
+ last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
448
+ pred_embeddings = last_hidden_state[seg_token_mask]
449
+
450
+ seg_token_counts = seg_token_mask.int().sum(-1) # [bs, ]
451
+ seg_token_offset = seg_token_counts.cumsum(-1)
452
+ seg_token_offset = torch.cat(
453
+ [torch.zeros(1).long().cuda(), seg_token_offset], dim=0
454
+ )
455
+
456
+ pred_embeddings_ = []
457
+ for i in range(len(seg_token_offset) - 1):
458
+ start_i, end_i = seg_token_offset[i], seg_token_offset[i + 1]
459
+ pred_embeddings_.append(pred_embeddings[start_i:end_i])
460
+ pred_embeddings = pred_embeddings_
461
+
462
+ image_embeddings = self.get_visual_embs(images)
463
+
464
+ multimask_output = False
465
+ pred_masks = []
466
+ for i in range(len(pred_embeddings)):
467
+ sparse_embeddings, dense_embeddings = self.visual_model.prompt_encoder(
468
+ points=None,
469
+ boxes=None,
470
+ masks=None,
471
+ text_embeds=pred_embeddings[i].unsqueeze(1),
472
+ )
473
+
474
+ sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
475
+ low_res_masks, iou_predictions = self.visual_model.mask_decoder(
476
+ image_embeddings=image_embeddings[i].unsqueeze(0),
477
+ image_pe=self.visual_model.prompt_encoder.get_dense_pe(),
478
+ sparse_prompt_embeddings=sparse_embeddings,
479
+ dense_prompt_embeddings=dense_embeddings,
480
+ multimask_output=multimask_output,
481
+ )
482
+
483
+ pred_mask = self.visual_model.postprocess_masks(
484
+ low_res_masks,
485
+ input_size=resize_list[i],
486
+ original_size=original_size_list[i],
487
+ )
488
+ pred_masks.append(pred_mask[:, 0])
489
+
490
+ return output_ids, pred_masks
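The new module-level `dice_loss` and `sigmoid_ce_loss` in `model/LISA.py` operate on (N, H, W) mask logits and binary targets. Below is a minimal sketch of combining them with the default weights from `LISA.__init__` (`bce_loss_weight=2.0`, `dice_loss_weight=0.5`), assuming the repository's dependencies are installed so the module imports cleanly. Note that `forward()` additionally accumulates these terms per batch element and normalizes by the total mask count, so this is a simplified illustration with random tensors.

```python
import torch

from model.LISA import dice_loss, sigmoid_ce_loss

num_masks = 3
pred = torch.randn(num_masks, 256, 256)                # predicted mask logits
gt = (torch.rand(num_masks, 256, 256) > 0.5).float()   # binary ground-truth masks

bce = sigmoid_ce_loss(pred, gt, num_masks=num_masks)
dice = dice_loss(pred, gt, num_masks=num_masks)
mask_loss = 2.0 * bce + 0.5 * dice                     # default LISA weighting
print(mask_loss.item())
```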
model/llava/conversation.py CHANGED
@@ -1,10 +1,11 @@
1
  import dataclasses
2
- from enum import auto, Enum
3
  from typing import List, Tuple
4
 
5
 
6
  class SeparatorStyle(Enum):
7
  """Different separator style."""
 
8
  SINGLE = auto()
9
  TWO = auto()
10
  MPT = auto()
@@ -13,6 +14,7 @@ class SeparatorStyle(Enum):
13
  @dataclasses.dataclass
14
  class Conversation:
15
  """A class that keeps all conversation history."""
 
16
  system: str
17
  roles: List[str]
18
  messages: List[List[str]]
@@ -64,33 +66,43 @@ class Conversation:
64
 
65
  def get_images(self, return_pil=False):
66
  images = []
67
- for i, (role, msg) in enumerate(self.messages[self.offset:]):
68
  if i % 2 == 0:
69
  if type(msg) is tuple:
70
  import base64
71
  from io import BytesIO
 
72
  from PIL import Image
 
73
  msg, image, image_process_mode = msg
74
  if image_process_mode == "Pad":
 
75
  def expand2square(pil_img, background_color=(122, 116, 104)):
76
  width, height = pil_img.size
77
  if width == height:
78
  return pil_img
79
  elif width > height:
80
- result = Image.new(pil_img.mode, (width, width), background_color)
 
 
81
  result.paste(pil_img, (0, (width - height) // 2))
82
  return result
83
  else:
84
- result = Image.new(pil_img.mode, (height, height), background_color)
 
 
85
  result.paste(pil_img, ((height - width) // 2, 0))
86
  return result
 
87
  image = expand2square(image)
88
  elif image_process_mode == "Crop":
89
  pass
90
  elif image_process_mode == "Resize":
91
  image = image.resize((224, 224))
92
  else:
93
- raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
 
 
94
  max_hw, min_hw = max(image.size), min(image.size)
95
  aspect_ratio = max_hw / min_hw
96
  max_len, min_len = 800, 400
@@ -113,11 +125,12 @@ class Conversation:
113
 
114
  def to_gradio_chatbot(self):
115
  ret = []
116
- for i, (role, msg) in enumerate(self.messages[self.offset:]):
117
  if i % 2 == 0:
118
  if type(msg) is tuple:
119
  import base64
120
  from io import BytesIO
 
121
  msg, image, image_process_mode = msg
122
  max_hw, min_hw = max(image.size), min(image.size)
123
  aspect_ratio = max_hw / min_hw
@@ -135,7 +148,7 @@ class Conversation:
135
  image.save(buffered, format="JPEG")
136
  img_b64_str = base64.b64encode(buffered.getvalue()).decode()
137
  img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
138
- msg = msg.replace('<image>', img_str)
139
  ret.append([msg, None])
140
  else:
141
  ret[-1][-1] = msg
@@ -149,14 +162,17 @@ class Conversation:
149
  offset=self.offset,
150
  sep_style=self.sep_style,
151
  sep=self.sep,
152
- sep2=self.sep2)
 
153
 
154
  def dict(self):
155
  if len(self.get_images()) > 0:
156
  return {
157
  "system": self.system,
158
  "roles": self.roles,
159
- "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
 
 
160
  "offset": self.offset,
161
  "sep": self.sep,
162
  "sep2": self.sep2,
@@ -173,11 +189,12 @@ class Conversation:
173
 
174
  conv_v1 = Conversation(
175
  system="A chat between a curious human and an artificial intelligence assistant. "
176
- "The assistant gives helpful, detailed, and polite answers to the human's questions.",
177
  roles=("Human", "Assistant"),
178
  messages=(
179
  ("Human", "Give three tips for staying healthy."),
180
- ("Assistant",
 
181
  "Sure, here are three tips for staying healthy:\n"
182
  "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
183
  "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
@@ -191,7 +208,8 @@ conv_v1 = Conversation(
191
  "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
192
  "and mental health. Adults should aim for seven to nine hours of sleep per night. "
193
  "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
194
- "help improve the quality of your sleep.")
 
195
  ),
196
  offset=2,
197
  sep_style=SeparatorStyle.SINGLE,
@@ -200,11 +218,15 @@ conv_v1 = Conversation(
200
 
201
  conv_v1_2 = Conversation(
202
  system="A chat between a curious human and an artificial intelligence assistant. "
203
- "The assistant gives helpful, detailed, and polite answers to the human's questions.",
204
  roles=("Human", "Assistant"),
205
  messages=(
206
- ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
207
- ("Assistant",
 
 
 
 
208
  "Renewable energy sources are those that can be replenished naturally in a relatively "
209
  "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
210
  "Non-renewable energy sources, on the other hand, are finite and will eventually be "
@@ -222,7 +244,8 @@ conv_v1_2 = Conversation(
222
  "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
223
  "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
224
  "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
225
- "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
 
226
  ),
227
  offset=2,
228
  sep_style=SeparatorStyle.SINGLE,
@@ -280,12 +303,12 @@ conv_bair_v1 = Conversation(
280
 
281
  simple_conv = Conversation(
282
  system="You are LLaVA, a large language model trained by UW Madison WAIV Lab, based on LLaMA architecture."
283
- "You are designed to assist human with a variety of tasks using natural language."
284
- "Follow the instructions carefully.",
285
  roles=("Human", "Assistant"),
286
  messages=(
287
  ("Human", "Hi!"),
288
- ("Assistant", "Hi there! How can I help you today?\n")
289
  ),
290
  offset=2,
291
  sep_style=SeparatorStyle.SINGLE,
@@ -294,12 +317,12 @@ simple_conv = Conversation(
294
 
295
  simple_conv_multimodal = Conversation(
296
  system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
297
- "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
298
- "Follow the instructions carefully and explain your answers in detail.",
299
  roles=("Human", "Assistant"),
300
  messages=(
301
  ("Human", "Hi!"),
302
- ("Assistant", "Hi there! How can I help you today?\n")
303
  ),
304
  offset=2,
305
  sep_style=SeparatorStyle.SINGLE,
@@ -321,12 +344,12 @@ simple_conv_mpt_multimodal = Conversation(
321
 
322
  simple_conv_legacy = Conversation(
323
  system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
324
- "You are designed to assist human with a variety of tasks using natural language."
325
- "Follow the instructions carefully.",
326
  roles=("Human", "Assistant"),
327
  messages=(
328
  ("Human", "Hi!\n\n### Response:"),
329
- ("Assistant", "Hi there! How can I help you today?\n")
330
  ),
331
  offset=2,
332
  sep_style=SeparatorStyle.SINGLE,
@@ -335,8 +358,8 @@ simple_conv_legacy = Conversation(
335
 
336
  conv_llava_v1 = Conversation(
337
  system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
338
- "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
339
- "Follow the instructions carefully and explain your answers in detail.",
340
  roles=("USER", "ASSISTANT"),
341
  version="v1",
342
  messages=(),
@@ -354,7 +377,6 @@ conv_templates = {
354
  "multimodal": simple_conv_multimodal,
355
  "mpt_multimodal": simple_conv_mpt_multimodal,
356
  "llava_v1": conv_llava_v1,
357
-
358
  # fastchat
359
  "v1": conv_v1_2,
360
  "bair_v1": conv_bair_v1,
 
1
  import dataclasses
2
+ from enum import Enum, auto
3
  from typing import List, Tuple
4
 
5
 
6
  class SeparatorStyle(Enum):
7
  """Different separator style."""
8
+
9
  SINGLE = auto()
10
  TWO = auto()
11
  MPT = auto()
 
14
  @dataclasses.dataclass
15
  class Conversation:
16
  """A class that keeps all conversation history."""
17
+
18
  system: str
19
  roles: List[str]
20
  messages: List[List[str]]
 
66
 
67
  def get_images(self, return_pil=False):
68
  images = []
69
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
70
  if i % 2 == 0:
71
  if type(msg) is tuple:
72
  import base64
73
  from io import BytesIO
74
+
75
  from PIL import Image
76
+
77
  msg, image, image_process_mode = msg
78
  if image_process_mode == "Pad":
79
+
80
  def expand2square(pil_img, background_color=(122, 116, 104)):
81
  width, height = pil_img.size
82
  if width == height:
83
  return pil_img
84
  elif width > height:
85
+ result = Image.new(
86
+ pil_img.mode, (width, width), background_color
87
+ )
88
  result.paste(pil_img, (0, (width - height) // 2))
89
  return result
90
  else:
91
+ result = Image.new(
92
+ pil_img.mode, (height, height), background_color
93
+ )
94
  result.paste(pil_img, ((height - width) // 2, 0))
95
  return result
96
+
97
  image = expand2square(image)
98
  elif image_process_mode == "Crop":
99
  pass
100
  elif image_process_mode == "Resize":
101
  image = image.resize((224, 224))
102
  else:
103
+ raise ValueError(
104
+ f"Invalid image_process_mode: {image_process_mode}"
105
+ )
106
  max_hw, min_hw = max(image.size), min(image.size)
107
  aspect_ratio = max_hw / min_hw
108
  max_len, min_len = 800, 400
 
125
 
126
  def to_gradio_chatbot(self):
127
  ret = []
128
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
129
  if i % 2 == 0:
130
  if type(msg) is tuple:
131
  import base64
132
  from io import BytesIO
133
+
134
  msg, image, image_process_mode = msg
135
  max_hw, min_hw = max(image.size), min(image.size)
136
  aspect_ratio = max_hw / min_hw
 
148
  image.save(buffered, format="JPEG")
149
  img_b64_str = base64.b64encode(buffered.getvalue()).decode()
150
  img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
151
+ msg = msg.replace("<image>", img_str)
152
  ret.append([msg, None])
153
  else:
154
  ret[-1][-1] = msg
 
162
  offset=self.offset,
163
  sep_style=self.sep_style,
164
  sep=self.sep,
165
+ sep2=self.sep2,
166
+ )
167
 
168
  def dict(self):
169
  if len(self.get_images()) > 0:
170
  return {
171
  "system": self.system,
172
  "roles": self.roles,
173
+ "messages": [
174
+ [x, y[0] if type(y) is tuple else y] for x, y in self.messages
175
+ ],
176
  "offset": self.offset,
177
  "sep": self.sep,
178
  "sep2": self.sep2,
 
189
 
190
  conv_v1 = Conversation(
191
  system="A chat between a curious human and an artificial intelligence assistant. "
192
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
193
  roles=("Human", "Assistant"),
194
  messages=(
195
  ("Human", "Give three tips for staying healthy."),
196
+ (
197
+ "Assistant",
198
  "Sure, here are three tips for staying healthy:\n"
199
  "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
200
  "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
 
208
  "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
209
  "and mental health. Adults should aim for seven to nine hours of sleep per night. "
210
  "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
211
+ "help improve the quality of your sleep.",
212
+ ),
213
  ),
214
  offset=2,
215
  sep_style=SeparatorStyle.SINGLE,
 
218
 
219
  conv_v1_2 = Conversation(
220
  system="A chat between a curious human and an artificial intelligence assistant. "
221
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
222
  roles=("Human", "Assistant"),
223
  messages=(
224
+ (
225
+ "Human",
226
+ "What are the key differences between renewable and non-renewable energy sources?",
227
+ ),
228
+ (
229
+ "Assistant",
230
  "Renewable energy sources are those that can be replenished naturally in a relatively "
231
  "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
232
  "Non-renewable energy sources, on the other hand, are finite and will eventually be "
 
244
  "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
245
  "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
246
  "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
247
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
248
+ ),
249
  ),
250
  offset=2,
251
  sep_style=SeparatorStyle.SINGLE,
 
303
 
304
  simple_conv = Conversation(
305
  system="You are LLaVA, a large language model trained by UW Madison WAIV Lab, based on LLaMA architecture."
306
+ "You are designed to assist human with a variety of tasks using natural language."
307
+ "Follow the instructions carefully.",
308
  roles=("Human", "Assistant"),
309
  messages=(
310
  ("Human", "Hi!"),
311
+ ("Assistant", "Hi there! How can I help you today?\n"),
312
  ),
313
  offset=2,
314
  sep_style=SeparatorStyle.SINGLE,
 
317
 
318
  simple_conv_multimodal = Conversation(
319
  system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
320
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
321
+ "Follow the instructions carefully and explain your answers in detail.",
322
  roles=("Human", "Assistant"),
323
  messages=(
324
  ("Human", "Hi!"),
325
+ ("Assistant", "Hi there! How can I help you today?\n"),
326
  ),
327
  offset=2,
328
  sep_style=SeparatorStyle.SINGLE,
 
344
 
345
  simple_conv_legacy = Conversation(
346
  system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
347
+ "You are designed to assist human with a variety of tasks using natural language."
348
+ "Follow the instructions carefully.",
349
  roles=("Human", "Assistant"),
350
  messages=(
351
  ("Human", "Hi!\n\n### Response:"),
352
+ ("Assistant", "Hi there! How can I help you today?\n"),
353
  ),
354
  offset=2,
355
  sep_style=SeparatorStyle.SINGLE,
 
358
 
359
  conv_llava_v1 = Conversation(
360
  system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
361
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
362
+ "Follow the instructions carefully and explain your answers in detail.",
363
  roles=("USER", "ASSISTANT"),
364
  version="v1",
365
  messages=(),
 
377
  "multimodal": simple_conv_multimodal,
378
  "mpt_multimodal": simple_conv_mpt_multimodal,
379
  "llava_v1": conv_llava_v1,
 
380
  # fastchat
381
  "v1": conv_v1_2,
382
  "bair_v1": conv_bair_v1,
model/llava/eval/eval_gpt_review.py CHANGED
@@ -1,25 +1,29 @@
1
  import argparse
2
  import json
3
  import os
 
4
 
5
  import openai
6
- import tqdm
7
  import ray
8
- import time
 
9
 
10
  @ray.remote(num_cpus=4)
11
  def get_eval(content: str, max_tokens: int):
12
  while True:
13
  try:
14
  response = openai.ChatCompletion.create(
15
- model='gpt-4',
16
- messages=[{
17
- 'role': 'system',
18
- 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
19
- }, {
20
- 'role': 'user',
21
- 'content': content,
22
- }],
 
 
 
23
  temperature=0.2, # TODO: figure out which temperature is best for evaluation
24
  max_tokens=max_tokens,
25
  )
@@ -30,34 +34,39 @@ def get_eval(content: str, max_tokens: int):
30
  print(e)
31
  time.sleep(1)
32
 
33
- print('success!')
34
- return response['choices'][0]['message']['content']
35
 
36
 
37
  def parse_score(review):
38
  try:
39
- score_pair = review.split('\n')[0]
40
- score_pair = score_pair.replace(',', ' ')
41
- sp = score_pair.split(' ')
42
  if len(sp) == 2:
43
  return [float(sp[0]), float(sp[1])]
44
  else:
45
- print('error', review)
46
  return [-1, -1]
47
  except Exception as e:
48
  print(e)
49
- print('error', review)
50
  return [-1, -1]
51
 
52
 
53
- if __name__ == '__main__':
54
- parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
55
- parser.add_argument('-q', '--question')
56
  # parser.add_argument('-a', '--answer')
57
- parser.add_argument('-a', '--answer-list', nargs='+', default=[])
58
- parser.add_argument('-r', '--rule')
59
- parser.add_argument('-o', '--output')
60
- parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
 
 
 
 
 
61
  args = parser.parse_args()
62
 
63
  ray.init()
@@ -65,9 +74,9 @@ if __name__ == '__main__':
65
  f_q = open(os.path.expanduser(args.question))
66
  f_ans1 = open(os.path.expanduser(args.answer_list[0]))
67
  f_ans2 = open(os.path.expanduser(args.answer_list[1]))
68
- rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
69
 
70
- review_file = open(f'{args.output}', 'w')
71
 
72
  js_list = []
73
  handles = []
@@ -80,23 +89,28 @@ if __name__ == '__main__':
80
  ans1 = json.loads(ans1_js)
81
  ans2 = json.loads(ans2_js)
82
 
83
- category = json.loads(ques_js)['category']
84
  if category in rule_dict:
85
  rule = rule_dict[category]
86
  else:
87
- rule = rule_dict['default']
88
- prompt = rule['prompt']
89
- role = rule['role']
90
- content = (f'[Question]\n{ques["text"]}\n\n'
91
- f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
92
- f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
93
- f'[System]\n{prompt}\n\n')
94
- js_list.append({
95
- 'id': idx+1,
96
- 'question_id': ques['question_id'],
97
- 'answer1_id': ans1['answer_id'],
98
- 'answer2_id': ans2['answer_id'],
99
- 'category': category})
 
 
 
 
 
100
  idx += 1
101
  handles.append(get_eval.remote(content, args.max_tokens))
102
  # To avoid the rate limit set by OpenAI
@@ -105,7 +119,7 @@ if __name__ == '__main__':
105
  reviews = ray.get(handles)
106
  for idx, review in enumerate(reviews):
107
  scores = parse_score(review)
108
- js_list[idx]['content'] = review
109
- js_list[idx]['tuple'] = scores
110
- review_file.write(json.dumps(js_list[idx]) + '\n')
111
  review_file.close()
 
1
  import argparse
2
  import json
3
  import os
4
+ import time
5
 
6
  import openai
 
7
  import ray
8
+ import tqdm
9
+
10
 
11
  @ray.remote(num_cpus=4)
12
  def get_eval(content: str, max_tokens: int):
13
  while True:
14
  try:
15
  response = openai.ChatCompletion.create(
16
+ model="gpt-4",
17
+ messages=[
18
+ {
19
+ "role": "system",
20
+ "content": "You are a helpful and precise assistant for checking the quality of the answer.",
21
+ },
22
+ {
23
+ "role": "user",
24
+ "content": content,
25
+ },
26
+ ],
27
  temperature=0.2, # TODO: figure out which temperature is best for evaluation
28
  max_tokens=max_tokens,
29
  )
 
34
  print(e)
35
  time.sleep(1)
36
 
37
+ print("success!")
38
+ return response["choices"][0]["message"]["content"]
39
 
40
 
41
  def parse_score(review):
42
  try:
43
+ score_pair = review.split("\n")[0]
44
+ score_pair = score_pair.replace(",", " ")
45
+ sp = score_pair.split(" ")
46
  if len(sp) == 2:
47
  return [float(sp[0]), float(sp[1])]
48
  else:
49
+ print("error", review)
50
  return [-1, -1]
51
  except Exception as e:
52
  print(e)
53
+ print("error", review)
54
  return [-1, -1]
55
 
56
 
57
+ if __name__ == "__main__":
58
+ parser = argparse.ArgumentParser(description="ChatGPT-based QA evaluation.")
59
+ parser.add_argument("-q", "--question")
60
  # parser.add_argument('-a', '--answer')
61
+ parser.add_argument("-a", "--answer-list", nargs="+", default=[])
62
+ parser.add_argument("-r", "--rule")
63
+ parser.add_argument("-o", "--output")
64
+ parser.add_argument(
65
+ "--max-tokens",
66
+ type=int,
67
+ default=1024,
68
+ help="maximum number of tokens produced in the output",
69
+ )
70
  args = parser.parse_args()
71
 
72
  ray.init()
 
74
  f_q = open(os.path.expanduser(args.question))
75
  f_ans1 = open(os.path.expanduser(args.answer_list[0]))
76
  f_ans2 = open(os.path.expanduser(args.answer_list[1]))
77
+ rule_dict = json.load(open(os.path.expanduser(args.rule), "r"))
78
 
79
+ review_file = open(f"{args.output}", "w")
80
 
81
  js_list = []
82
  handles = []
 
89
  ans1 = json.loads(ans1_js)
90
  ans2 = json.loads(ans2_js)
91
 
92
+ category = json.loads(ques_js)["category"]
93
  if category in rule_dict:
94
  rule = rule_dict[category]
95
  else:
96
+ rule = rule_dict["default"]
97
+ prompt = rule["prompt"]
98
+ role = rule["role"]
99
+ content = (
100
+ f'[Question]\n{ques["text"]}\n\n'
101
+ f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
102
+ f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
103
+ f"[System]\n{prompt}\n\n"
104
+ )
105
+ js_list.append(
106
+ {
107
+ "id": idx + 1,
108
+ "question_id": ques["question_id"],
109
+ "answer1_id": ans1["answer_id"],
110
+ "answer2_id": ans2["answer_id"],
111
+ "category": category,
112
+ }
113
+ )
114
  idx += 1
115
  handles.append(get_eval.remote(content, args.max_tokens))
116
  # To avoid the rate limit set by OpenAI
 
119
  reviews = ray.get(handles)
120
  for idx, review in enumerate(reviews):
121
  scores = parse_score(review)
122
+ js_list[idx]["content"] = review
123
+ js_list[idx]["tuple"] = scores
124
+ review_file.write(json.dumps(js_list[idx]) + "\n")
125
  review_file.close()
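
For reference, the score parsing that eval_gpt_review.py relies on expects the first line of each GPT-4 review to contain exactly two numbers; anything else falls back to `[-1, -1]`. A standalone sketch of that rule (the sample strings are made up):

```
# Standalone sketch of the score-parsing rule above: the first line of a
# GPT-4 review is expected to hold two numbers, e.g. "8 9" or "8,9".
def parse_score(review: str):
    try:
        first_line = review.split("\n")[0].replace(",", " ")
        parts = first_line.split(" ")
        if len(parts) == 2:
            return [float(parts[0]), float(parts[1])]
        return [-1, -1]
    except Exception:
        return [-1, -1]


print(parse_score("8 9\nAssistant 1 was more detailed..."))  # [8.0, 9.0]
print(parse_score("Score: 8/10"))                            # [-1, -1] (unparseable)
```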
model/llava/eval/eval_gpt_review_visual.py CHANGED
@@ -1,25 +1,29 @@
1
  import argparse
2
  import json
3
  import os
 
4
 
5
  import openai
6
- import tqdm
7
  import ray
8
- import time
 
9
 
10
  @ray.remote(num_cpus=4)
11
  def get_eval(content: str, max_tokens: int):
12
  while True:
13
  try:
14
  response = openai.ChatCompletion.create(
15
- model='gpt-4',
16
- messages=[{
17
- 'role': 'system',
18
- 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
19
- }, {
20
- 'role': 'user',
21
- 'content': content,
22
- }],
 
 
 
23
  temperature=0.2, # TODO: figure out which temperature is best for evaluation
24
  max_tokens=max_tokens,
25
  )
@@ -30,34 +34,39 @@ def get_eval(content: str, max_tokens: int):
30
  print(e)
31
  time.sleep(1)
32
 
33
- print('success!')
34
- return response['choices'][0]['message']['content']
35
 
36
 
37
  def parse_score(review):
38
  try:
39
- score_pair = review.split('\n')[0]
40
- score_pair = score_pair.replace(',', ' ')
41
- sp = score_pair.split(' ')
42
  if len(sp) == 2:
43
  return [float(sp[0]), float(sp[1])]
44
  else:
45
- print('error', review)
46
  return [-1, -1]
47
  except Exception as e:
48
  print(e)
49
- print('error', review)
50
  return [-1, -1]
51
 
52
 
53
- if __name__ == '__main__':
54
- parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
55
- parser.add_argument('-q', '--question')
56
- parser.add_argument('-c', '--context')
57
- parser.add_argument('-a', '--answer-list', nargs='+', default=[])
58
- parser.add_argument('-r', '--rule')
59
- parser.add_argument('-o', '--output')
60
- parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
 
 
 
 
 
61
  args = parser.parse_args()
62
 
63
  ray.init()
@@ -65,12 +74,12 @@ if __name__ == '__main__':
65
  f_q = open(os.path.expanduser(args.question))
66
  f_ans1 = open(os.path.expanduser(args.answer_list[0]))
67
  f_ans2 = open(os.path.expanduser(args.answer_list[1]))
68
- rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
69
 
70
- review_file = open(f'{args.output}', 'w')
71
 
72
  context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
73
- image_to_context = {context['image']: context for context in context_list}
74
 
75
  js_list = []
76
  handles = []
@@ -80,28 +89,38 @@ if __name__ == '__main__':
80
  ans1 = json.loads(ans1_js)
81
  ans2 = json.loads(ans2_js)
82
 
83
- inst = image_to_context[ques['image']]
84
- cap_str = '\n'.join(inst['captions'])
85
- box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']])
 
 
 
 
 
86
 
87
- category = json.loads(ques_js)['category']
88
  if category in rule_dict:
89
  rule = rule_dict[category]
90
  else:
91
  assert False, f"Visual QA category not found in rule file: {category}."
92
- prompt = rule['prompt']
93
- role = rule['role']
94
- content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n'
95
- f'[Question]\n{ques["text"]}\n\n'
96
- f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
97
- f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
98
- f'[System]\n{prompt}\n\n')
99
- js_list.append({
100
- 'id': idx+1,
101
- 'question_id': ques['question_id'],
102
- 'answer1_id': ans1.get('answer_id', ans1['question_id']),
103
- 'answer2_id': ans2.get('answer_id', ans2['answer_id']),
104
- 'category': category})
 
 
 
 
 
105
  idx += 1
106
  handles.append(get_eval.remote(content, args.max_tokens))
107
  # To avoid the rate limit set by OpenAI
@@ -110,7 +129,7 @@ if __name__ == '__main__':
110
  reviews = ray.get(handles)
111
  for idx, review in enumerate(reviews):
112
  scores = parse_score(review)
113
- js_list[idx]['content'] = review
114
- js_list[idx]['tuple'] = scores
115
- review_file.write(json.dumps(js_list[idx]) + '\n')
116
  review_file.close()
 
1
  import argparse
2
  import json
3
  import os
4
+ import time
5
 
6
  import openai
 
7
  import ray
8
+ import tqdm
9
+
10
 
11
  @ray.remote(num_cpus=4)
12
  def get_eval(content: str, max_tokens: int):
13
  while True:
14
  try:
15
  response = openai.ChatCompletion.create(
16
+ model="gpt-4",
17
+ messages=[
18
+ {
19
+ "role": "system",
20
+ "content": "You are a helpful and precise assistant for checking the quality of the answer.",
21
+ },
22
+ {
23
+ "role": "user",
24
+ "content": content,
25
+ },
26
+ ],
27
  temperature=0.2, # TODO: figure out which temperature is best for evaluation
28
  max_tokens=max_tokens,
29
  )
 
34
  print(e)
35
  time.sleep(1)
36
 
37
+ print("success!")
38
+ return response["choices"][0]["message"]["content"]
39
 
40
 
41
  def parse_score(review):
42
  try:
43
+ score_pair = review.split("\n")[0]
44
+ score_pair = score_pair.replace(",", " ")
45
+ sp = score_pair.split(" ")
46
  if len(sp) == 2:
47
  return [float(sp[0]), float(sp[1])]
48
  else:
49
+ print("error", review)
50
  return [-1, -1]
51
  except Exception as e:
52
  print(e)
53
+ print("error", review)
54
  return [-1, -1]
55
 
56
 
57
+ if __name__ == "__main__":
58
+ parser = argparse.ArgumentParser(description="ChatGPT-based QA evaluation.")
59
+ parser.add_argument("-q", "--question")
60
+ parser.add_argument("-c", "--context")
61
+ parser.add_argument("-a", "--answer-list", nargs="+", default=[])
62
+ parser.add_argument("-r", "--rule")
63
+ parser.add_argument("-o", "--output")
64
+ parser.add_argument(
65
+ "--max-tokens",
66
+ type=int,
67
+ default=1024,
68
+ help="maximum number of tokens produced in the output",
69
+ )
70
  args = parser.parse_args()
71
 
72
  ray.init()
 
74
  f_q = open(os.path.expanduser(args.question))
75
  f_ans1 = open(os.path.expanduser(args.answer_list[0]))
76
  f_ans2 = open(os.path.expanduser(args.answer_list[1]))
77
+ rule_dict = json.load(open(os.path.expanduser(args.rule), "r"))
78
 
79
+ review_file = open(f"{args.output}", "w")
80
 
81
  context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
82
+ image_to_context = {context["image"]: context for context in context_list}
83
 
84
  js_list = []
85
  handles = []
 
89
  ans1 = json.loads(ans1_js)
90
  ans2 = json.loads(ans2_js)
91
 
92
+ inst = image_to_context[ques["image"]]
93
+ cap_str = "\n".join(inst["captions"])
94
+ box_str = "\n".join(
95
+ [
96
+ f'{instance["category"]}: {instance["bbox"]}'
97
+ for instance in inst["instances"]
98
+ ]
99
+ )
100
 
101
+ category = json.loads(ques_js)["category"]
102
  if category in rule_dict:
103
  rule = rule_dict[category]
104
  else:
105
  assert False, f"Visual QA category not found in rule file: {category}."
106
+ prompt = rule["prompt"]
107
+ role = rule["role"]
108
+ content = (
109
+ f"[Context]\n{cap_str}\n\n{box_str}\n\n"
110
+ f'[Question]\n{ques["text"]}\n\n'
111
+ f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
112
+ f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
113
+ f"[System]\n{prompt}\n\n"
114
+ )
115
+ js_list.append(
116
+ {
117
+ "id": idx + 1,
118
+ "question_id": ques["question_id"],
119
+ "answer1_id": ans1.get("answer_id", ans1["question_id"]),
120
+ "answer2_id": ans2.get("answer_id", ans2["answer_id"]),
121
+ "category": category,
122
+ }
123
+ )
124
  idx += 1
125
  handles.append(get_eval.remote(content, args.max_tokens))
126
  # To avoid the rate limit set by OpenAI
 
129
  reviews = ray.get(handles)
130
  for idx, review in enumerate(reviews):
131
  scores = parse_score(review)
132
+ js_list[idx]["content"] = review
133
+ js_list[idx]["tuple"] = scores
134
+ review_file.write(json.dumps(js_list[idx]) + "\n")
135
  review_file.close()
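
The visual variant builds its review prompt from the context file's captions and detection boxes before the question and the two answers. A minimal sketch of that assembly with made-up data (the rule prompt text here is a placeholder):

```
# Sketch of how the review prompt above is assembled from the context file
# (captions + detection boxes) and the two candidate answers.
inst = {
    "captions": ["A dog runs on the beach."],
    "instances": [{"category": "dog", "bbox": [10, 20, 120, 200]}],
}
cap_str = "\n".join(inst["captions"])
box_str = "\n".join(f'{d["category"]}: {d["bbox"]}' for d in inst["instances"])

role, prompt = "Assistant", "Rate both answers from 1 to 10."  # placeholder rule
question = {"text": "What animal is in the picture?"}
ans1 = {"text": "A dog."}
ans2 = {"text": "A cat."}

content = (
    f"[Context]\n{cap_str}\n\n{box_str}\n\n"
    f"[Question]\n{question['text']}\n\n"
    f"[{role} 1]\n{ans1['text']}\n\n[End of {role} 1]\n\n"
    f"[{role} 2]\n{ans2['text']}\n\n[End of {role} 2]\n\n"
    f"[System]\n{prompt}\n\n"
)
print(content)
```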
model/llava/eval/eval_science_qa.py CHANGED
@@ -1,26 +1,26 @@
1
  import argparse
2
  import json
3
  import os
4
- import re
5
  import random
 
6
 
7
 
8
  def get_args():
9
  parser = argparse.ArgumentParser()
10
- parser.add_argument('--base-dir', type=str)
11
- parser.add_argument('--result-file', type=str)
12
- parser.add_argument('--output-file', type=str)
13
- parser.add_argument('--output-result', type=str)
14
- parser.add_argument('--split', type=str, default='test')
15
- parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16
  return parser.parse_args()
17
 
18
 
19
  def convert_caps(results):
20
  fakecaps = []
21
  for result in results:
22
- image_id = result['question_id']
23
- caption = result['text']
24
  fakecaps.append({"image_id": int(image_id), "caption": caption})
25
  return fakecaps
26
 
@@ -29,7 +29,7 @@ def get_pred_idx(prediction, choices, options):
29
  """
30
  Get the index (e.g. 2) from the prediction (e.g. 'C')
31
  """
32
- if prediction in options[:len(choices)]:
33
  return options.index(prediction)
34
  else:
35
  return random.choice(range(len(choices)))
@@ -39,61 +39,65 @@ if __name__ == "__main__":
39
  args = get_args()
40
 
41
  base_dir = args.base_dir
42
- split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
 
 
43
  problems = json.load(open(os.path.join(base_dir, "problems.json")))
44
  predictions = [json.loads(line) for line in open(args.result_file)]
45
- predictions = {pred['question_id']: pred for pred in predictions}
46
  split_problems = {idx: problems[idx] for idx in split_indices}
47
 
48
- results = {'correct': [], 'incorrect': []}
49
  sqa_results = {}
50
- sqa_results['acc'] = None
51
- sqa_results['correct'] = None
52
- sqa_results['count'] = None
53
- sqa_results['results'] = {}
54
- sqa_results['outputs'] = {}
55
 
56
  for prob_id, prob in split_problems.items():
57
  if prob_id not in predictions:
58
  continue
59
  pred = predictions[prob_id]
60
- pred_text = pred['text']
61
 
62
- pattern = re.compile(r'The answer is ([A-Z]).')
63
  res = pattern.findall(pred_text)
64
  if len(res) == 1:
65
  answer = res[0] # 'A', 'B', ...
66
  else:
67
  answer = "FAILED"
68
 
69
- pred_idx = get_pred_idx(answer, prob['choices'], args.options)
70
 
71
  analysis = {
72
- 'question_id': prob_id,
73
- 'parsed_ans': answer,
74
- 'ground_truth': args.options[prob['answer']],
75
- 'question': pred['prompt'],
76
- 'pred': pred_text,
77
- 'is_multimodal': '<image>' in pred['prompt'],
78
  }
79
 
80
- sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options)
81
- sqa_results['outputs'][prob_id] = pred_text
 
 
82
 
83
- if pred_idx == prob['answer']:
84
- results['correct'].append(analysis)
85
  else:
86
- results['incorrect'].append(analysis)
87
 
88
- correct = len(results['correct'])
89
- total = len(results['correct']) + len(results['incorrect'])
90
- print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')
91
 
92
- sqa_results['acc'] = correct / total * 100
93
- sqa_results['correct'] = correct
94
- sqa_results['count'] = total
95
 
96
- with open(args.output_file, 'w') as f:
97
  json.dump(results, f, indent=2)
98
- with open(args.output_result, 'w') as f:
99
  json.dump(sqa_results, f, indent=2)
 
1
  import argparse
2
  import json
3
  import os
 
4
  import random
5
+ import re
6
 
7
 
8
  def get_args():
9
  parser = argparse.ArgumentParser()
10
+ parser.add_argument("--base-dir", type=str)
11
+ parser.add_argument("--result-file", type=str)
12
+ parser.add_argument("--output-file", type=str)
13
+ parser.add_argument("--output-result", type=str)
14
+ parser.add_argument("--split", type=str, default="test")
15
+ parser.add_argument("--options", type=list, default=["A", "B", "C", "D", "E"])
16
  return parser.parse_args()
17
 
18
 
19
  def convert_caps(results):
20
  fakecaps = []
21
  for result in results:
22
+ image_id = result["question_id"]
23
+ caption = result["text"]
24
  fakecaps.append({"image_id": int(image_id), "caption": caption})
25
  return fakecaps
26
 
 
29
  """
30
  Get the index (e.g. 2) from the prediction (e.g. 'C')
31
  """
32
+ if prediction in options[: len(choices)]:
33
  return options.index(prediction)
34
  else:
35
  return random.choice(range(len(choices)))
 
39
  args = get_args()
40
 
41
  base_dir = args.base_dir
42
+ split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[
43
+ args.split
44
+ ]
45
  problems = json.load(open(os.path.join(base_dir, "problems.json")))
46
  predictions = [json.loads(line) for line in open(args.result_file)]
47
+ predictions = {pred["question_id"]: pred for pred in predictions}
48
  split_problems = {idx: problems[idx] for idx in split_indices}
49
 
50
+ results = {"correct": [], "incorrect": []}
51
  sqa_results = {}
52
+ sqa_results["acc"] = None
53
+ sqa_results["correct"] = None
54
+ sqa_results["count"] = None
55
+ sqa_results["results"] = {}
56
+ sqa_results["outputs"] = {}
57
 
58
  for prob_id, prob in split_problems.items():
59
  if prob_id not in predictions:
60
  continue
61
  pred = predictions[prob_id]
62
+ pred_text = pred["text"]
63
 
64
+ pattern = re.compile(r"The answer is ([A-Z]).")
65
  res = pattern.findall(pred_text)
66
  if len(res) == 1:
67
  answer = res[0] # 'A', 'B', ...
68
  else:
69
  answer = "FAILED"
70
 
71
+ pred_idx = get_pred_idx(answer, prob["choices"], args.options)
72
 
73
  analysis = {
74
+ "question_id": prob_id,
75
+ "parsed_ans": answer,
76
+ "ground_truth": args.options[prob["answer"]],
77
+ "question": pred["prompt"],
78
+ "pred": pred_text,
79
+ "is_multimodal": "<image>" in pred["prompt"],
80
  }
81
 
82
+ sqa_results["results"][prob_id] = get_pred_idx(
83
+ answer, prob["choices"], args.options
84
+ )
85
+ sqa_results["outputs"][prob_id] = pred_text
86
 
87
+ if pred_idx == prob["answer"]:
88
+ results["correct"].append(analysis)
89
  else:
90
+ results["incorrect"].append(analysis)
91
 
92
+ correct = len(results["correct"])
93
+ total = len(results["correct"]) + len(results["incorrect"])
94
+ print(f"Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%")
95
 
96
+ sqa_results["acc"] = correct / total * 100
97
+ sqa_results["correct"] = correct
98
+ sqa_results["count"] = total
99
 
100
+ with open(args.output_file, "w") as f:
101
  json.dump(results, f, indent=2)
102
+ with open(args.output_result, "w") as f:
103
  json.dump(sqa_results, f, indent=2)
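
eval_science_qa.py grades predictions by pulling the option letter out of the literal phrase "The answer is X." and mapping it onto the per-question choice list, falling back to a random choice when parsing fails. A small self-contained sketch of that logic:

```
# Sketch of the ScienceQA answer parsing above: extract the letter from
# "The answer is X." and fall back to a random index when parsing fails.
import random
import re


def get_pred_idx(prediction, choices, options=("A", "B", "C", "D", "E")):
    if prediction in options[: len(choices)]:
        return options.index(prediction)
    return random.choice(range(len(choices)))


pattern = re.compile(r"The answer is ([A-Z]).")
pred_text = "The answer is B. Solar power is renewable."
match = pattern.findall(pred_text)
answer = match[0] if len(match) == 1 else "FAILED"
print(answer, get_pred_idx(answer, choices=["coal", "solar", "wind"]))  # B 1
```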
model/llava/eval/eval_science_qa_gpt4.py CHANGED
@@ -1,26 +1,26 @@
1
  import argparse
2
  import json
3
  import os
4
- import re
5
  import random
 
6
  from collections import defaultdict
7
 
8
 
9
  def get_args():
10
  parser = argparse.ArgumentParser()
11
- parser.add_argument('--base-dir', type=str)
12
- parser.add_argument('--gpt4-result', type=str)
13
- parser.add_argument('--our-result', type=str)
14
- parser.add_argument('--split', type=str, default='test')
15
- parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16
  return parser.parse_args()
17
 
18
 
19
  def convert_caps(results):
20
  fakecaps = []
21
  for result in results:
22
- image_id = result['question_id']
23
- caption = result['text']
24
  fakecaps.append({"image_id": int(image_id), "caption": caption})
25
  return fakecaps
26
 
@@ -29,7 +29,7 @@ def get_pred_idx(prediction, choices, options):
29
  """
30
  Get the index (e.g. 2) from the prediction (e.g. 'C')
31
  """
32
- if prediction in options[:len(choices)]:
33
  return options.index(prediction)
34
  else:
35
  return random.choice(range(len(choices)))
@@ -39,13 +39,15 @@ if __name__ == "__main__":
39
  args = get_args()
40
 
41
  base_dir = args.base_dir
42
- split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
 
 
43
  problems = json.load(open(os.path.join(base_dir, "problems.json")))
44
  our_predictions = [json.loads(line) for line in open(args.our_result)]
45
- our_predictions = {pred['question_id']: pred for pred in our_predictions}
46
  split_problems = {idx: problems[idx] for idx in split_indices}
47
 
48
- gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
49
 
50
  results = defaultdict(lambda: 0)
51
 
@@ -54,10 +56,10 @@ if __name__ == "__main__":
54
  continue
55
  if prob_id not in gpt4_predictions:
56
  continue
57
- our_pred = our_predictions[prob_id]['text']
58
  gpt4_pred = gpt4_predictions[prob_id]
59
 
60
- pattern = re.compile(r'The answer is ([A-Z]).')
61
  our_res = pattern.findall(our_pred)
62
  if len(our_res) == 1:
63
  our_answer = our_res[0] # 'A', 'B', ...
@@ -69,11 +71,11 @@ if __name__ == "__main__":
69
  else:
70
  gpt4_answer = "FAILED"
71
 
72
- our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
73
- gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
74
 
75
- if gpt4_answer == 'FAILED':
76
- results['gpt4_failed'] += 1
77
  # continue
78
  gpt4_pred_idx = our_pred_idx
79
  # if our_pred_idx != prob['answer']:
@@ -87,18 +89,20 @@ if __name__ == "__main__":
87
  pass
88
  # gpt4_pred_idx = our_pred_idx
89
 
90
- if gpt4_pred_idx == prob['answer']:
91
- results['correct'] += 1
92
  else:
93
- results['incorrect'] += 1
94
-
95
-
96
- if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
97
- results['correct_upperbound'] += 1
98
-
99
- correct = results['correct']
100
- total = results['correct'] + results['incorrect']
101
- print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')
102
- print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
103
- print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
104
-
 
 
 
1
  import argparse
2
  import json
3
  import os
 
4
  import random
5
+ import re
6
  from collections import defaultdict
7
 
8
 
9
  def get_args():
10
  parser = argparse.ArgumentParser()
11
+ parser.add_argument("--base-dir", type=str)
12
+ parser.add_argument("--gpt4-result", type=str)
13
+ parser.add_argument("--our-result", type=str)
14
+ parser.add_argument("--split", type=str, default="test")
15
+ parser.add_argument("--options", type=list, default=["A", "B", "C", "D", "E"])
16
  return parser.parse_args()
17
 
18
 
19
  def convert_caps(results):
20
  fakecaps = []
21
  for result in results:
22
+ image_id = result["question_id"]
23
+ caption = result["text"]
24
  fakecaps.append({"image_id": int(image_id), "caption": caption})
25
  return fakecaps
26
 
 
29
  """
30
  Get the index (e.g. 2) from the prediction (e.g. 'C')
31
  """
32
+ if prediction in options[: len(choices)]:
33
  return options.index(prediction)
34
  else:
35
  return random.choice(range(len(choices)))
 
39
  args = get_args()
40
 
41
  base_dir = args.base_dir
42
+ split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[
43
+ args.split
44
+ ]
45
  problems = json.load(open(os.path.join(base_dir, "problems.json")))
46
  our_predictions = [json.loads(line) for line in open(args.our_result)]
47
+ our_predictions = {pred["question_id"]: pred for pred in our_predictions}
48
  split_problems = {idx: problems[idx] for idx in split_indices}
49
 
50
+ gpt4_predictions = json.load(open(args.gpt4_result))["outputs"]
51
 
52
  results = defaultdict(lambda: 0)
53
 
 
56
  continue
57
  if prob_id not in gpt4_predictions:
58
  continue
59
+ our_pred = our_predictions[prob_id]["text"]
60
  gpt4_pred = gpt4_predictions[prob_id]
61
 
62
+ pattern = re.compile(r"The answer is ([A-Z]).")
63
  our_res = pattern.findall(our_pred)
64
  if len(our_res) == 1:
65
  our_answer = our_res[0] # 'A', 'B', ...
 
71
  else:
72
  gpt4_answer = "FAILED"
73
 
74
+ our_pred_idx = get_pred_idx(our_answer, prob["choices"], args.options)
75
+ gpt4_pred_idx = get_pred_idx(gpt4_answer, prob["choices"], args.options)
76
 
77
+ if gpt4_answer == "FAILED":
78
+ results["gpt4_failed"] += 1
79
  # continue
80
  gpt4_pred_idx = our_pred_idx
81
  # if our_pred_idx != prob['answer']:
 
89
  pass
90
  # gpt4_pred_idx = our_pred_idx
91
 
92
+ if gpt4_pred_idx == prob["answer"]:
93
+ results["correct"] += 1
94
  else:
95
+ results["incorrect"] += 1
96
+
97
+ if gpt4_pred_idx == prob["answer"] or our_pred_idx == prob["answer"]:
98
+ results["correct_upperbound"] += 1
99
+
100
+ correct = results["correct"]
101
+ total = results["correct"] + results["incorrect"]
102
+ print(f"Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%")
103
+ print(
104
+ f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%'
105
+ )
106
+ print(
107
+ f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%'
108
+ )
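
The accuracy bookkeeping in eval_science_qa_gpt4.py counts GPT-4 correctness and an upper bound where either GPT-4 or the LLaVA prediction is right. A toy illustration with made-up predictions (three questions, gold answers given as choice indices):

```
# Toy illustration of the counters above: per-question correctness for GPT-4,
# plus the upper bound where either GPT-4 or "ours" is correct.
from collections import defaultdict

gold = {"q1": 0, "q2": 1, "q3": 2}
gpt4_pred = {"q1": 0, "q2": 0, "q3": 2}
our_pred = {"q1": 1, "q2": 1, "q3": 2}

results = defaultdict(int)
for qid, answer in gold.items():
    if gpt4_pred[qid] == answer:
        results["correct"] += 1
    else:
        results["incorrect"] += 1
    if gpt4_pred[qid] == answer or our_pred[qid] == answer:
        results["correct_upperbound"] += 1

total = results["correct"] + results["incorrect"]
print(
    f"Total: {total}, Correct: {results['correct']}, "
    f"Upper bound: {results['correct_upperbound']}"
)
# Total: 3, Correct: 2, Upper bound: 3
```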
model/llava/eval/eval_science_qa_gpt4_requery.py CHANGED
@@ -1,28 +1,28 @@
1
  import argparse
2
  import json
3
  import os
4
- import re
5
  import random
 
6
  from collections import defaultdict
7
 
8
 
9
  def get_args():
10
  parser = argparse.ArgumentParser()
11
- parser.add_argument('--base-dir', type=str)
12
- parser.add_argument('--gpt4-result', type=str)
13
- parser.add_argument('--requery-result', type=str)
14
- parser.add_argument('--our-result', type=str)
15
- parser.add_argument('--output-result', type=str)
16
- parser.add_argument('--split', type=str, default='test')
17
- parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
18
  return parser.parse_args()
19
 
20
 
21
  def convert_caps(results):
22
  fakecaps = []
23
  for result in results:
24
- image_id = result['question_id']
25
- caption = result['text']
26
  fakecaps.append({"image_id": int(image_id), "caption": caption})
27
  return fakecaps
28
 
@@ -31,7 +31,7 @@ def get_pred_idx(prediction, choices, options):
31
  """
32
  Get the index (e.g. 2) from the prediction (e.g. 'C')
33
  """
34
- if prediction in options[:len(choices)]:
35
  return options.index(prediction)
36
  else:
37
  return random.choice(range(len(choices)))
@@ -41,40 +41,42 @@ if __name__ == "__main__":
41
  args = get_args()
42
 
43
  base_dir = args.base_dir
44
- split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
 
 
45
  problems = json.load(open(os.path.join(base_dir, "problems.json")))
46
  our_predictions = [json.loads(line) for line in open(args.our_result)]
47
- our_predictions = {pred['question_id']: pred for pred in our_predictions}
48
  split_problems = {idx: problems[idx] for idx in split_indices}
49
 
50
  requery_predictions = [json.loads(line) for line in open(args.requery_result)]
51
- requery_predictions = {pred['question_id']: pred for pred in requery_predictions}
52
 
53
- gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
54
 
55
  results = defaultdict(lambda: 0)
56
 
57
  sqa_results = {}
58
- sqa_results['acc'] = None
59
- sqa_results['correct'] = None
60
- sqa_results['count'] = None
61
- sqa_results['results'] = {}
62
- sqa_results['outputs'] = {}
63
 
64
  for prob_id, prob in split_problems.items():
65
  if prob_id not in our_predictions:
66
  assert False
67
  if prob_id not in gpt4_predictions:
68
  assert False
69
- our_pred = our_predictions[prob_id]['text']
70
  gpt4_pred = gpt4_predictions[prob_id]
71
  if prob_id not in requery_predictions:
72
- results['missing_requery'] += 1
73
  requery_pred = "MISSING"
74
  else:
75
- requery_pred = requery_predictions[prob_id]['text']
76
 
77
- pattern = re.compile(r'The answer is ([A-Z]).')
78
  our_res = pattern.findall(our_pred)
79
  if len(our_res) == 1:
80
  our_answer = our_res[0] # 'A', 'B', ...
@@ -93,57 +95,70 @@ if __name__ == "__main__":
93
  else:
94
  gpt4_answer = "FAILED"
95
 
96
- our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
97
- gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
98
- requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options)
99
-
100
- results['total'] += 1
101
-
102
- if gpt4_answer == 'FAILED':
103
- results['gpt4_failed'] += 1
104
- if gpt4_pred_idx == prob['answer']:
105
- results['gpt4_correct'] += 1
106
- if our_pred_idx == prob['answer']:
107
- results['gpt4_ourvisual_correct'] += 1
108
- elif gpt4_pred_idx == prob['answer']:
109
- results['gpt4_correct'] += 1
110
- results['gpt4_ourvisual_correct'] += 1
111
-
112
- if our_pred_idx == prob['answer']:
113
- results['our_correct'] += 1
114
-
115
- if requery_answer == 'FAILED':
116
- sqa_results['results'][prob_id] = our_pred_idx
117
- if our_pred_idx == prob['answer']:
118
- results['requery_correct'] += 1
119
  else:
120
- sqa_results['results'][prob_id] = requery_pred_idx
121
- if requery_pred_idx == prob['answer']:
122
- results['requery_correct'] += 1
123
  else:
124
- print(f"""
 
125
  Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']}
126
  Our ({our_answer}): {our_pred}
127
  GPT-4 ({gpt4_answer}): {gpt4_pred}
128
  Requery ({requery_answer}): {requery_pred}
129
  print("=====================================")
130
- """)
131
-
132
- if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
133
- results['correct_upperbound'] += 1
134
-
135
- total = results['total']
136
- print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%')
137
- print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%')
138
- print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
139
- print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%')
140
- print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%')
141
- print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
142
-
143
- sqa_results['acc'] = results["requery_correct"] / total * 100
144
- sqa_results['correct'] = results["requery_correct"]
145
- sqa_results['count'] = total
146
-
147
- with open(args.output_result, 'w') as f:
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  json.dump(sqa_results, f, indent=2)
149
-
 
1
  import argparse
2
  import json
3
  import os
 
4
  import random
5
+ import re
6
  from collections import defaultdict
7
 
8
 
9
  def get_args():
10
  parser = argparse.ArgumentParser()
11
+ parser.add_argument("--base-dir", type=str)
12
+ parser.add_argument("--gpt4-result", type=str)
13
+ parser.add_argument("--requery-result", type=str)
14
+ parser.add_argument("--our-result", type=str)
15
+ parser.add_argument("--output-result", type=str)
16
+ parser.add_argument("--split", type=str, default="test")
17
+ parser.add_argument("--options", type=list, default=["A", "B", "C", "D", "E"])
18
  return parser.parse_args()
19
 
20
 
21
  def convert_caps(results):
22
  fakecaps = []
23
  for result in results:
24
+ image_id = result["question_id"]
25
+ caption = result["text"]
26
  fakecaps.append({"image_id": int(image_id), "caption": caption})
27
  return fakecaps
28
 
 
31
  """
32
  Get the index (e.g. 2) from the prediction (e.g. 'C')
33
  """
34
+ if prediction in options[: len(choices)]:
35
  return options.index(prediction)
36
  else:
37
  return random.choice(range(len(choices)))
 
41
  args = get_args()
42
 
43
  base_dir = args.base_dir
44
+ split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[
45
+ args.split
46
+ ]
47
  problems = json.load(open(os.path.join(base_dir, "problems.json")))
48
  our_predictions = [json.loads(line) for line in open(args.our_result)]
49
+ our_predictions = {pred["question_id"]: pred for pred in our_predictions}
50
  split_problems = {idx: problems[idx] for idx in split_indices}
51
 
52
  requery_predictions = [json.loads(line) for line in open(args.requery_result)]
53
+ requery_predictions = {pred["question_id"]: pred for pred in requery_predictions}
54
 
55
+ gpt4_predictions = json.load(open(args.gpt4_result))["outputs"]
56
 
57
  results = defaultdict(lambda: 0)
58
 
59
  sqa_results = {}
60
+ sqa_results["acc"] = None
61
+ sqa_results["correct"] = None
62
+ sqa_results["count"] = None
63
+ sqa_results["results"] = {}
64
+ sqa_results["outputs"] = {}
65
 
66
  for prob_id, prob in split_problems.items():
67
  if prob_id not in our_predictions:
68
  assert False
69
  if prob_id not in gpt4_predictions:
70
  assert False
71
+ our_pred = our_predictions[prob_id]["text"]
72
  gpt4_pred = gpt4_predictions[prob_id]
73
  if prob_id not in requery_predictions:
74
+ results["missing_requery"] += 1
75
  requery_pred = "MISSING"
76
  else:
77
+ requery_pred = requery_predictions[prob_id]["text"]
78
 
79
+ pattern = re.compile(r"The answer is ([A-Z]).")
80
  our_res = pattern.findall(our_pred)
81
  if len(our_res) == 1:
82
  our_answer = our_res[0] # 'A', 'B', ...
 
95
  else:
96
  gpt4_answer = "FAILED"
97
 
98
+ our_pred_idx = get_pred_idx(our_answer, prob["choices"], args.options)
99
+ gpt4_pred_idx = get_pred_idx(gpt4_answer, prob["choices"], args.options)
100
+ requery_pred_idx = get_pred_idx(requery_answer, prob["choices"], args.options)
101
+
102
+ results["total"] += 1
103
+
104
+ if gpt4_answer == "FAILED":
105
+ results["gpt4_failed"] += 1
106
+ if gpt4_pred_idx == prob["answer"]:
107
+ results["gpt4_correct"] += 1
108
+ if our_pred_idx == prob["answer"]:
109
+ results["gpt4_ourvisual_correct"] += 1
110
+ elif gpt4_pred_idx == prob["answer"]:
111
+ results["gpt4_correct"] += 1
112
+ results["gpt4_ourvisual_correct"] += 1
113
+
114
+ if our_pred_idx == prob["answer"]:
115
+ results["our_correct"] += 1
116
+
117
+ if requery_answer == "FAILED":
118
+ sqa_results["results"][prob_id] = our_pred_idx
119
+ if our_pred_idx == prob["answer"]:
120
+ results["requery_correct"] += 1
121
  else:
122
+ sqa_results["results"][prob_id] = requery_pred_idx
123
+ if requery_pred_idx == prob["answer"]:
124
+ results["requery_correct"] += 1
125
  else:
126
+ print(
127
+ f"""
128
  Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']}
129
  Our ({our_answer}): {our_pred}
130
  GPT-4 ({gpt4_answer}): {gpt4_pred}
131
  Requery ({requery_answer}): {requery_pred}
132
  print("=====================================")
133
+ """
134
+ )
135
+
136
+ if gpt4_pred_idx == prob["answer"] or our_pred_idx == prob["answer"]:
137
+ results["correct_upperbound"] += 1
138
+
139
+ total = results["total"]
140
+ print(
141
+ f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%'
142
+ )
143
+ print(
144
+ f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%'
145
+ )
146
+ print(
147
+ f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%'
148
+ )
149
+ print(
150
+ f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%'
151
+ )
152
+ print(
153
+ f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%'
154
+ )
155
+ print(
156
+ f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%'
157
+ )
158
+
159
+ sqa_results["acc"] = results["requery_correct"] / total * 100
160
+ sqa_results["correct"] = results["requery_correct"]
161
+ sqa_results["count"] = total
162
+
163
+ with open(args.output_result, "w") as f:
164
  json.dump(sqa_results, f, indent=2)
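
The requery script's core decision is simple: keep the requeried answer when it parsed, otherwise fall back to the original prediction for that question. A minimal sketch of that fallback (the function name is hypothetical):

```
# Sketch of the requery fallback above: prefer the requeried prediction,
# fall back to the original one when the requery answer failed to parse.
def pick_final(our_idx, requery_idx, requery_answer):
    return our_idx if requery_answer == "FAILED" else requery_idx


print(pick_final(our_idx=1, requery_idx=2, requery_answer="C"))       # 2
print(pick_final(our_idx=1, requery_idx=0, requery_answer="FAILED"))  # 1
```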
 
model/llava/eval/generate_webpage_data_from_table.py CHANGED
@@ -4,10 +4,10 @@ import os
4
  import re
5
 
6
  # models = ['llama', 'alpaca', 'gpt35', 'bard']
7
- models = ['vicuna']
8
 
9
 
10
- def read_jsonl(path: str, key: str=None):
11
  data = []
12
  with open(os.path.expanduser(path)) as f:
13
  for line in f:
@@ -23,21 +23,27 @@ def read_jsonl(path: str, key: str=None):
23
  def trim_hanging_lines(s: str, n: int) -> str:
24
  s = s.strip()
25
  for _ in range(n):
26
- s = s.split('\n', 1)[1].strip()
27
  return s
28
 
29
 
30
- if __name__ == '__main__':
31
- questions = read_jsonl('table/question.jsonl', key='question_id')
32
 
33
  # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id')
34
  # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id')
35
  # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id')
36
  # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id')
37
- vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id')
38
- ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id')
 
 
 
 
39
 
40
- review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id')
 
 
41
  # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id')
42
  # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id')
43
  # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id')
@@ -46,26 +52,26 @@ if __name__ == '__main__':
46
  records = []
47
  for qid in questions.keys():
48
  r = {
49
- 'id': qid,
50
- 'category': questions[qid]['category'],
51
- 'question': questions[qid]['text'],
52
- 'answers': {
53
  # 'alpaca': alpaca_answers[qid]['text'],
54
  # 'llama': llama_answers[qid]['text'],
55
  # 'bard': bard_answers[qid]['text'],
56
  # 'gpt35': gpt35_answers[qid]['text'],
57
- 'vicuna': vicuna_answers[qid]['text'],
58
- 'ours': ours_answers[qid]['text'],
59
  },
60
- 'evaluations': {
61
  # 'alpaca': review_alpaca[qid]['text'],
62
  # 'llama': review_llama[qid]['text'],
63
  # 'bard': review_bard[qid]['text'],
64
- 'vicuna': review_vicuna[qid]['content'],
65
  # 'gpt35': review_gpt35[qid]['text'],
66
  },
67
- 'scores': {
68
- 'vicuna': review_vicuna[qid]['tuple'],
69
  # 'alpaca': review_alpaca[qid]['score'],
70
  # 'llama': review_llama[qid]['score'],
71
  # 'bard': review_bard[qid]['score'],
@@ -75,37 +81,39 @@ if __name__ == '__main__':
75
 
76
  # cleanup data
77
  cleaned_evals = {}
78
- for k, v in r['evaluations'].items():
79
  v = v.strip()
80
- lines = v.split('\n')
81
  # trim the first line if it's a pair of numbers
82
- if re.match(r'\d+[, ]+\d+', lines[0]):
83
  lines = lines[1:]
84
- v = '\n'.join(lines)
85
- cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**')
 
 
86
 
87
- r['evaluations'] = cleaned_evals
88
  records.append(r)
89
 
90
  # Reorder the records, this is optional
91
  for r in records:
92
- if r['id'] <= 20:
93
- r['id'] += 60
94
  else:
95
- r['id'] -= 20
96
  for r in records:
97
- if r['id'] <= 50:
98
- r['id'] += 10
99
- elif 50 < r['id'] <= 60:
100
- r['id'] -= 50
101
  for r in records:
102
- if r['id'] == 7:
103
- r['id'] = 1
104
- elif r['id'] < 7:
105
- r['id'] += 1
106
 
107
- records.sort(key=lambda x: x['id'])
108
 
109
  # Write to file
110
- with open('webpage/data.json', 'w') as f:
111
- json.dump({'questions': records, 'models': models}, f, indent=2)
 
4
  import re
5
 
6
  # models = ['llama', 'alpaca', 'gpt35', 'bard']
7
+ models = ["vicuna"]
8
 
9
 
10
+ def read_jsonl(path: str, key: str = None):
11
  data = []
12
  with open(os.path.expanduser(path)) as f:
13
  for line in f:
 
23
  def trim_hanging_lines(s: str, n: int) -> str:
24
  s = s.strip()
25
  for _ in range(n):
26
+ s = s.split("\n", 1)[1].strip()
27
  return s
28
 
29
 
30
+ if __name__ == "__main__":
31
+ questions = read_jsonl("table/question.jsonl", key="question_id")
32
 
33
  # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id')
34
  # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id')
35
  # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id')
36
  # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id')
37
+ vicuna_answers = read_jsonl(
38
+ "table/answer/answer_vicuna-13b.jsonl", key="question_id"
39
+ )
40
+ ours_answers = read_jsonl(
41
+ "table/results/llama-13b-hf-alpaca.jsonl", key="question_id"
42
+ )
43
 
44
+ review_vicuna = read_jsonl(
45
+ "table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl", key="question_id"
46
+ )
47
  # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id')
48
  # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id')
49
  # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id')
 
52
  records = []
53
  for qid in questions.keys():
54
  r = {
55
+ "id": qid,
56
+ "category": questions[qid]["category"],
57
+ "question": questions[qid]["text"],
58
+ "answers": {
59
  # 'alpaca': alpaca_answers[qid]['text'],
60
  # 'llama': llama_answers[qid]['text'],
61
  # 'bard': bard_answers[qid]['text'],
62
  # 'gpt35': gpt35_answers[qid]['text'],
63
+ "vicuna": vicuna_answers[qid]["text"],
64
+ "ours": ours_answers[qid]["text"],
65
  },
66
+ "evaluations": {
67
  # 'alpaca': review_alpaca[qid]['text'],
68
  # 'llama': review_llama[qid]['text'],
69
  # 'bard': review_bard[qid]['text'],
70
+ "vicuna": review_vicuna[qid]["content"],
71
  # 'gpt35': review_gpt35[qid]['text'],
72
  },
73
+ "scores": {
74
+ "vicuna": review_vicuna[qid]["tuple"],
75
  # 'alpaca': review_alpaca[qid]['score'],
76
  # 'llama': review_llama[qid]['score'],
77
  # 'bard': review_bard[qid]['score'],
 
81
 
82
  # cleanup data
83
  cleaned_evals = {}
84
+ for k, v in r["evaluations"].items():
85
  v = v.strip()
86
+ lines = v.split("\n")
87
  # trim the first line if it's a pair of numbers
88
+ if re.match(r"\d+[, ]+\d+", lines[0]):
89
  lines = lines[1:]
90
+ v = "\n".join(lines)
91
+ cleaned_evals[k] = v.replace("Assistant 1", "**Assistant 1**").replace(
92
+ "Assistant 2", "**Assistant 2**"
93
+ )
94
 
95
+ r["evaluations"] = cleaned_evals
96
  records.append(r)
97
 
98
  # Reorder the records, this is optional
99
  for r in records:
100
+ if r["id"] <= 20:
101
+ r["id"] += 60
102
  else:
103
+ r["id"] -= 20
104
  for r in records:
105
+ if r["id"] <= 50:
106
+ r["id"] += 10
107
+ elif 50 < r["id"] <= 60:
108
+ r["id"] -= 50
109
  for r in records:
110
+ if r["id"] == 7:
111
+ r["id"] = 1
112
+ elif r["id"] < 7:
113
+ r["id"] += 1
114
 
115
+ records.sort(key=lambda x: x["id"])
116
 
117
  # Write to file
118
+ with open("webpage/data.json", "w") as f:
119
+ json.dump({"questions": records, "models": models}, f, indent=2)
model/llava/eval/model_qa.py CHANGED
@@ -1,13 +1,13 @@
1
  import argparse
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
3
- import torch
4
- import os
5
  import json
6
- from tqdm import tqdm
7
- import shortuuid
8
 
 
 
9
  from llava.conversation import default_conversation
10
  from llava.utils import disable_torch_init
 
 
11
 
12
 
13
  # new stopping implementation
@@ -18,11 +18,15 @@ class KeywordsStoppingCriteria(StoppingCriteria):
18
  self.start_len = None
19
  self.input_ids = input_ids
20
 
21
- def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
 
 
22
  if self.start_len is None:
23
  self.start_len = self.input_ids.shape[1]
24
  else:
25
- outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
 
 
26
  for keyword in self.keywords:
27
  if keyword in outputs:
28
  return True
@@ -35,9 +39,9 @@ def eval_model(model_name, questions_file, answers_file):
35
  disable_torch_init()
36
  model_name = os.path.expanduser(model_name)
37
  tokenizer = AutoTokenizer.from_pretrained(model_name)
38
- model = AutoModelForCausalLM.from_pretrained(model_name,
39
- torch_dtype=torch.float16).cuda()
40
-
41
 
42
  ques_file = open(os.path.expanduser(questions_file), "r")
43
  ans_file = open(os.path.expanduser(answers_file), "w")
@@ -56,7 +60,8 @@ def eval_model(model_name, questions_file, answers_file):
56
  do_sample=True,
57
  temperature=0.7,
58
  max_new_tokens=1024,
59
- stopping_criteria=[stopping_criteria])
 
60
  outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
61
  try:
62
  index = outputs.index(conv.sep, len(prompt))
@@ -64,16 +69,24 @@ def eval_model(model_name, questions_file, answers_file):
64
  outputs += conv.sep
65
  index = outputs.index(conv.sep, len(prompt))
66
 
67
- outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
68
  ans_id = shortuuid.uuid()
69
- ans_file.write(json.dumps({"question_id": idx,
70
- "text": outputs,
71
- "answer_id": ans_id,
72
- "model_id": model_name,
73
- "metadata": {}}) + "\n")
 
 
 
 
 
 
 
74
  ans_file.flush()
75
  ans_file.close()
76
 
 
77
  if __name__ == "__main__":
78
  parser = argparse.ArgumentParser()
79
  parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
 
1
  import argparse
 
 
 
2
  import json
3
+ import os
 
4
 
5
+ import shortuuid
6
+ import torch
7
  from llava.conversation import default_conversation
8
  from llava.utils import disable_torch_init
9
+ from tqdm import tqdm
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria
11
 
12
 
13
  # new stopping implementation
 
18
  self.start_len = None
19
  self.input_ids = input_ids
20
 
21
+ def __call__(
22
+ self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
23
+ ) -> bool:
24
  if self.start_len is None:
25
  self.start_len = self.input_ids.shape[1]
26
  else:
27
+ outputs = self.tokenizer.batch_decode(
28
+ output_ids[:, self.start_len :], skip_special_tokens=True
29
+ )[0]
30
  for keyword in self.keywords:
31
  if keyword in outputs:
32
  return True
 
39
  disable_torch_init()
40
  model_name = os.path.expanduser(model_name)
41
  tokenizer = AutoTokenizer.from_pretrained(model_name)
42
+ model = AutoModelForCausalLM.from_pretrained(
43
+ model_name, torch_dtype=torch.float16
44
+ ).cuda()
45
 
46
  ques_file = open(os.path.expanduser(questions_file), "r")
47
  ans_file = open(os.path.expanduser(answers_file), "w")
 
60
  do_sample=True,
61
  temperature=0.7,
62
  max_new_tokens=1024,
63
+ stopping_criteria=[stopping_criteria],
64
+ )
65
  outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
66
  try:
67
  index = outputs.index(conv.sep, len(prompt))
 
69
  outputs += conv.sep
70
  index = outputs.index(conv.sep, len(prompt))
71
 
72
+ outputs = outputs[len(prompt) + len(conv.roles[1]) + 2 : index].strip()
73
  ans_id = shortuuid.uuid()
74
+ ans_file.write(
75
+ json.dumps(
76
+ {
77
+ "question_id": idx,
78
+ "text": outputs,
79
+ "answer_id": ans_id,
80
+ "model_id": model_name,
81
+ "metadata": {},
82
+ }
83
+ )
84
+ + "\n"
85
+ )
86
  ans_file.flush()
87
  ans_file.close()
88
 
89
+
90
  if __name__ == "__main__":
91
  parser = argparse.ArgumentParser()
92
  parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
model/llava/eval/model_vqa.py CHANGED
@@ -1,25 +1,25 @@
1
  import argparse
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
3
- import torch
4
- import os
5
  import json
6
- from tqdm import tqdm
7
- import shortuuid
 
8
 
 
 
9
  from llava import LlavaLlamaForCausalLM
10
  from llava.conversation import conv_templates
11
  from llava.utils import disable_torch_init
12
- from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
13
-
14
  from PIL import Image
15
- import random
16
- import math
 
 
17
 
18
 
19
  def split_list(lst, n):
20
  """Split a list into n (roughly) equal-sized chunks"""
21
  chunk_size = math.ceil(len(lst) / n) # integer division
22
- return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
23
 
24
 
25
  def get_chunk(lst, n, k):
@@ -37,12 +37,14 @@ def patch_config(config):
37
  patch_dict = {
38
  "use_mm_proj": True,
39
  "mm_vision_tower": "openai/clip-vit-large-patch14",
40
- "mm_hidden_size": 1024
41
  }
42
 
43
  cfg = AutoConfig.from_pretrained(config)
44
  if not hasattr(cfg, "mm_vision_tower"):
45
- print(f'`mm_vision_tower` not found in `{config}`, applying patch and save to disk.')
 
 
46
  for k, v in patch_dict.items():
47
  setattr(cfg, k, v)
48
  cfg.save_pretrained(config)
@@ -55,50 +57,84 @@ def eval_model(args):
55
  tokenizer = AutoTokenizer.from_pretrained(model_name)
56
  if args.mm_projector is None:
57
  patch_config(model_name)
58
- model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda()
59
- image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)
60
 
61
  mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
62
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
63
  if mm_use_im_start_end:
64
- tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
 
 
65
 
66
  vision_tower = model.model.vision_tower[0]
67
- vision_tower.to(device='cuda', dtype=torch.float16)
68
  vision_config = vision_tower.config
69
- vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
 
 
70
  vision_config.use_im_start_end = mm_use_im_start_end
71
  if mm_use_im_start_end:
72
- vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
73
  image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
74
  else:
75
  # in case of using a pretrained model with only a MLP projector weights
76
- model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda()
 
 
77
 
78
- vision_tower = CLIPVisionModel.from_pretrained(args.vision_tower, torch_dtype=torch.float16).cuda()
79
- image_processor = CLIPImageProcessor.from_pretrained(args.vision_tower, torch_dtype=torch.float16)
80
 
81
  mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
82
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
83
  if mm_use_im_start_end:
84
- tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
 
 
85
 
86
  vision_config = vision_tower.config
87
- vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
 
 
88
  vision_config.use_im_start_end = mm_use_im_start_end
89
  if mm_use_im_start_end:
90
- vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
91
 
92
  image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
93
 
94
- mm_projector = torch.nn.Linear(vision_config.hidden_size, model.config.hidden_size)
95
- mm_projector_weights = torch.load(args.mm_projector, map_location='cpu')
96
- mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})
97
 
98
  model.model.mm_projector = mm_projector.cuda().half()
99
  model.model.vision_tower = [vision_tower]
100
 
101
- questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
 
 
102
  questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
103
  answers_file = os.path.expanduser(args.answers_file)
104
  os.makedirs(os.path.dirname(answers_file), exist_ok=True)
@@ -109,12 +145,18 @@ def eval_model(args):
109
  qs = line["text"]
110
  cur_prompt = qs
111
  if mm_use_im_start_end:
112
- qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
113
  else:
114
- qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
115
 
116
- if args.conv_mode == 'simple_legacy':
117
- qs += '\n\n### Response:'
118
  # conv = default_conversation.copy()
119
  conv = conv_templates[args.conv_mode].copy()
120
  conv.append_message(conv.roles[0], qs)
@@ -123,7 +165,9 @@ def eval_model(args):
123
 
124
  image = Image.open(os.path.join(args.image_folder, image_file))
125
  # image.save(os.path.join(save_image_folder, image_file))
126
- image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
 
 
127
 
128
  input_ids = torch.as_tensor(inputs.input_ids).cuda()
129
 
@@ -135,17 +179,21 @@ def eval_model(args):
135
  self.start_len = None
136
  self.input_ids = input_ids
137
 
138
- def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
 
 
139
  if self.start_len is None:
140
  self.start_len = self.input_ids.shape[1]
141
  else:
142
- outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
 
 
143
  for keyword in self.keywords:
144
  if keyword in outputs:
145
  return True
146
  return False
147
 
148
- keywords = ['###']
149
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
150
 
151
  with torch.inference_mode():
@@ -155,21 +203,28 @@ def eval_model(args):
155
  do_sample=True,
156
  temperature=0.7,
157
  max_new_tokens=1024,
158
- stopping_criteria=[stopping_criteria])
 
159
 
160
  input_token_len = input_ids.shape[1]
161
- n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
 
 
162
  if n_diff_input_output > 0:
163
- print(f'[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids')
164
- outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
165
-
166
- if args.conv_mode == 'simple_legacy' or args.conv_mode == 'simple':
167
  while True:
168
  cur_len = len(outputs)
169
  outputs = outputs.strip()
170
- for pattern in ['###', 'Assistant:', 'Response:']:
171
  if outputs.startswith(pattern):
172
- outputs = outputs[len(pattern):].strip()
173
  if len(outputs) == cur_len:
174
  break
175
 
@@ -182,15 +237,23 @@ def eval_model(args):
182
  outputs = outputs[:index].strip()
183
 
184
  ans_id = shortuuid.uuid()
185
- ans_file.write(json.dumps({"question_id": idx,
186
- "prompt": cur_prompt,
187
- "text": outputs,
188
- "answer_id": ans_id,
189
- "model_id": model_name,
190
- "metadata": {}}) + "\n")
 
 
 
 
 
 
 
191
  ans_file.flush()
192
  ans_file.close()
193
 
 
194
  if __name__ == "__main__":
195
  parser = argparse.ArgumentParser()
196
  parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
 
1
  import argparse
 
 
 
2
  import json
3
+ import math
4
+ import os
5
+ import random
6
 
7
+ import shortuuid
8
+ import torch
9
  from llava import LlavaLlamaForCausalLM
10
  from llava.conversation import conv_templates
11
  from llava.utils import disable_torch_init
 
 
12
  from PIL import Image
13
+ from tqdm import tqdm
14
+ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
15
+ CLIPImageProcessor, CLIPVisionModel,
16
+ StoppingCriteria)
17
 
18
 
19
  def split_list(lst, n):
20
  """Split a list into n (roughly) equal-sized chunks"""
21
  chunk_size = math.ceil(len(lst) / n) # integer division
22
+ return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
23
 
24
 
25
  def get_chunk(lst, n, k):
 
37
  patch_dict = {
38
  "use_mm_proj": True,
39
  "mm_vision_tower": "openai/clip-vit-large-patch14",
40
+ "mm_hidden_size": 1024,
41
  }
42
 
43
  cfg = AutoConfig.from_pretrained(config)
44
  if not hasattr(cfg, "mm_vision_tower"):
45
+ print(
46
+ f"`mm_vision_tower` not found in `{config}`, applying patch and save to disk."
47
+ )
48
  for k, v in patch_dict.items():
49
  setattr(cfg, k, v)
50
  cfg.save_pretrained(config)
 
57
  tokenizer = AutoTokenizer.from_pretrained(model_name)
58
  if args.mm_projector is None:
59
  patch_config(model_name)
60
+ model = LlavaLlamaForCausalLM.from_pretrained(
61
+ model_name, torch_dtype=torch.float16
62
+ ).cuda()
63
+ image_processor = CLIPImageProcessor.from_pretrained(
64
+ model.config.mm_vision_tower, torch_dtype=torch.float16
65
+ )
66
 
67
  mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
68
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
69
  if mm_use_im_start_end:
70
+ tokenizer.add_tokens(
71
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
72
+ )
73
 
74
  vision_tower = model.model.vision_tower[0]
75
+ vision_tower.to(device="cuda", dtype=torch.float16)
76
  vision_config = vision_tower.config
77
+ vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
78
+ [DEFAULT_IMAGE_PATCH_TOKEN]
79
+ )[0]
80
  vision_config.use_im_start_end = mm_use_im_start_end
81
  if mm_use_im_start_end:
82
+ (
83
+ vision_config.im_start_token,
84
+ vision_config.im_end_token,
85
+ ) = tokenizer.convert_tokens_to_ids(
86
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
87
+ )
88
  image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
89
  else:
90
  # in case of using a pretrained model with only a MLP projector weights
91
+ model = LlavaLlamaForCausalLM.from_pretrained(
92
+ model_name, torch_dtype=torch.float16
93
+ ).cuda()
94
 
95
+ vision_tower = CLIPVisionModel.from_pretrained(
96
+ args.vision_tower, torch_dtype=torch.float16
97
+ ).cuda()
98
+ image_processor = CLIPImageProcessor.from_pretrained(
99
+ args.vision_tower, torch_dtype=torch.float16
100
+ )
101
 
102
  mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
103
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
104
  if mm_use_im_start_end:
105
+ tokenizer.add_tokens(
106
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
107
+ )
108
 
109
  vision_config = vision_tower.config
110
+ vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
111
+ [DEFAULT_IMAGE_PATCH_TOKEN]
112
+ )[0]
113
  vision_config.use_im_start_end = mm_use_im_start_end
114
  if mm_use_im_start_end:
115
+ (
116
+ vision_config.im_start_token,
117
+ vision_config.im_end_token,
118
+ ) = tokenizer.convert_tokens_to_ids(
119
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
120
+ )
121
 
122
  image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
123
 
124
+ mm_projector = torch.nn.Linear(
125
+ vision_config.hidden_size, model.config.hidden_size
126
+ )
127
+ mm_projector_weights = torch.load(args.mm_projector, map_location="cpu")
128
+ mm_projector.load_state_dict(
129
+ {k.split(".")[-1]: v for k, v in mm_projector_weights.items()}
130
+ )
131
 
132
  model.model.mm_projector = mm_projector.cuda().half()
133
  model.model.vision_tower = [vision_tower]
134
 
135
+ questions = [
136
+ json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")
137
+ ]
138
  questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
139
  answers_file = os.path.expanduser(args.answers_file)
140
  os.makedirs(os.path.dirname(answers_file), exist_ok=True)
 
145
  qs = line["text"]
146
  cur_prompt = qs
147
  if mm_use_im_start_end:
148
+ qs = (
149
+ qs
150
+ + "\n"
151
+ + DEFAULT_IM_START_TOKEN
152
+ + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
153
+ + DEFAULT_IM_END_TOKEN
154
+ )
155
  else:
156
+ qs = qs + "\n" + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
157
 
158
+ if args.conv_mode == "simple_legacy":
159
+ qs += "\n\n### Response:"
160
  # conv = default_conversation.copy()
161
  conv = conv_templates[args.conv_mode].copy()
162
  conv.append_message(conv.roles[0], qs)
 
165
 
166
  image = Image.open(os.path.join(args.image_folder, image_file))
167
  # image.save(os.path.join(save_image_folder, image_file))
168
+ image_tensor = image_processor.preprocess(image, return_tensors="pt")[
169
+ "pixel_values"
170
+ ][0]
171
 
172
  input_ids = torch.as_tensor(inputs.input_ids).cuda()
173
 
 
179
  self.start_len = None
180
  self.input_ids = input_ids
181
 
182
+ def __call__(
183
+ self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
184
+ ) -> bool:
185
  if self.start_len is None:
186
  self.start_len = self.input_ids.shape[1]
187
  else:
188
+ outputs = self.tokenizer.batch_decode(
189
+ output_ids[:, self.start_len :], skip_special_tokens=True
190
+ )[0]
191
  for keyword in self.keywords:
192
  if keyword in outputs:
193
  return True
194
  return False
195
 
196
+ keywords = ["###"]
197
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
198
 
199
  with torch.inference_mode():
 
203
  do_sample=True,
204
  temperature=0.7,
205
  max_new_tokens=1024,
206
+ stopping_criteria=[stopping_criteria],
207
+ )
208
 
209
  input_token_len = input_ids.shape[1]
210
+ n_diff_input_output = (
211
+ (input_ids != output_ids[:, :input_token_len]).sum().item()
212
+ )
213
  if n_diff_input_output > 0:
214
+ print(
215
+ f"[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids"
216
+ )
217
+ outputs = tokenizer.batch_decode(
218
+ output_ids[:, input_token_len:], skip_special_tokens=True
219
+ )[0]
220
+
221
+ if args.conv_mode == "simple_legacy" or args.conv_mode == "simple":
222
  while True:
223
  cur_len = len(outputs)
224
  outputs = outputs.strip()
225
+ for pattern in ["###", "Assistant:", "Response:"]:
226
  if outputs.startswith(pattern):
227
+ outputs = outputs[len(pattern) :].strip()
228
  if len(outputs) == cur_len:
229
  break
230
 
 
237
  outputs = outputs[:index].strip()
238
 
239
  ans_id = shortuuid.uuid()
240
+ ans_file.write(
241
+ json.dumps(
242
+ {
243
+ "question_id": idx,
244
+ "prompt": cur_prompt,
245
+ "text": outputs,
246
+ "answer_id": ans_id,
247
+ "model_id": model_name,
248
+ "metadata": {},
249
+ }
250
+ )
251
+ + "\n"
252
+ )
253
  ans_file.flush()
254
  ans_file.close()
255
 
256
+
257
  if __name__ == "__main__":
258
  parser = argparse.ArgumentParser()
259
  parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
model/llava/eval/model_vqa_science.py CHANGED
@@ -1,25 +1,25 @@
1
  import argparse
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
3
- import torch
4
- import os
5
  import json
6
- from tqdm import tqdm
7
- import shortuuid
 
8
 
 
 
9
  from llava import LlavaLlamaForCausalLM
10
  from llava.conversation import conv_templates
11
  from llava.utils import disable_torch_init
12
- from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
13
-
14
  from PIL import Image
15
- import random
16
- import math
 
 
17
 
18
 
19
  def split_list(lst, n):
20
  """Split a list into n (roughly) equal-sized chunks"""
21
  chunk_size = math.ceil(len(lst) / n) # integer division
22
- return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
23
 
24
 
25
  def get_chunk(lst, n, k):
@@ -33,8 +33,6 @@ DEFAULT_IM_START_TOKEN = "<im_start>"
33
  DEFAULT_IM_END_TOKEN = "<im_end>"
34
 
35
 
36
-
37
-
38
  detail_describe_instructions = [
39
  "Describe the following image in detail.",
40
  "Provide a detailed description of the given image.",
@@ -70,19 +68,21 @@ concise_describe_instructions = [
70
 
71
  prompt_pool = detail_describe_instructions + concise_describe_instructions
72
 
73
- prompt_pool = [ "Describe the following image in detail."]
74
 
75
 
76
  def patch_config(config):
77
  patch_dict = {
78
  "use_mm_proj": True,
79
  "mm_vision_tower": "openai/clip-vit-large-patch14",
80
- "mm_hidden_size": 1024
81
  }
82
 
83
  cfg = AutoConfig.from_pretrained(config)
84
  if not hasattr(cfg, "mm_vision_tower"):
85
- print(f'`mm_vision_tower` not found in `{config}`, applying patch and save to disk.')
 
 
86
  for k, v in patch_dict.items():
87
  setattr(cfg, k, v)
88
  cfg.save_pretrained(config)
@@ -96,11 +96,15 @@ class KeywordsStoppingCriteria(StoppingCriteria):
96
  self.start_len = None
97
  self.input_ids = input_ids
98
 
99
- def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
 
 
100
  if self.start_len is None:
101
  self.start_len = self.input_ids.shape[1]
102
  else:
103
- outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
 
 
104
  for keyword in self.keywords:
105
  if keyword in outputs:
106
  return True
@@ -114,45 +118,77 @@ def eval_model(args):
114
  tokenizer = AutoTokenizer.from_pretrained(model_name)
115
  if args.mm_projector is None:
116
  patch_config(model_name)
117
- model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True).cuda()
118
- image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)
 
 
 
 
119
 
120
  mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
121
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
122
  if mm_use_im_start_end:
123
- tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
 
 
124
 
125
  vision_tower = model.model.vision_tower[0]
126
- vision_tower.to(device='cuda', dtype=torch.float16)
127
  vision_config = vision_tower.config
128
- vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
 
 
129
  vision_config.use_im_start_end = mm_use_im_start_end
130
  if mm_use_im_start_end:
131
- vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
132
  image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
133
  else:
134
  # in case of using a pretrained model with only a MLP projector weights
135
- model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True).cuda()
 
 
136
 
137
  mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
138
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
139
  if mm_use_im_start_end:
140
- tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
 
 
141
 
142
- vision_tower = CLIPVisionModel.from_pretrained(args.vision_tower, torch_dtype=torch.float16).cuda()
143
- image_processor = CLIPImageProcessor.from_pretrained(args.vision_tower, torch_dtype=torch.float16)
 
 
 
 
144
 
145
  vision_config = vision_tower.config
146
- vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
 
 
147
  vision_config.use_im_start_end = mm_use_im_start_end
148
  if mm_use_im_start_end:
149
- vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
150
 
151
  image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
152
 
153
- mm_projector = torch.nn.Linear(vision_config.hidden_size, model.config.hidden_size)
154
- mm_projector_weights = torch.load(args.mm_projector, map_location='cpu')
155
- mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})
 
 
 
 
156
 
157
  model.model.mm_projector = mm_projector.cuda().half()
158
  model.model.vision_tower = [vision_tower]
@@ -163,33 +199,43 @@ def eval_model(args):
163
  os.makedirs(os.path.dirname(answers_file), exist_ok=True)
164
  os.makedirs(os.path.join(os.path.dirname(answers_file), "images"), exist_ok=True)
165
  ans_file = open(answers_file, "w")
166
- save_image_folder = os.path.join(os.path.dirname(os.path.expanduser(args.answers_file)), "images")
 
 
167
  for i, line in enumerate(tqdm(questions)):
168
  idx = line["id"]
169
- question = line['conversations'][0]
170
  gt_ans = line["conversations"][1]
171
-
172
- qs = question['value']
173
 
174
- qs = qs.replace('<image>', '').strip()
 
 
175
  cur_prompt = qs
176
 
177
- if 'image' in line:
178
  image_file = line["image"]
179
  image = Image.open(os.path.join(args.image_folder, image_file))
180
- image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
 
 
181
  images = image_tensor.unsqueeze(0).half().cuda()
182
- if getattr(model.config, 'mm_use_im_start_end', False):
183
- qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
184
  else:
185
- qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
186
- cur_prompt = cur_prompt + '\n' + '<image>'
187
  else:
188
  images = None
189
 
190
- if args.conv_mode == 'simple_legacy':
191
- qs += '\n\n### Response:'
192
- assert gt_ans['from'] == 'gpt'
193
  # conv = default_conversation.copy()
194
  conv = conv_templates[args.conv_mode].copy()
195
  conv.append_message(conv.roles[0], qs)
@@ -198,7 +244,7 @@ def eval_model(args):
198
 
199
  input_ids = torch.as_tensor(inputs.input_ids).cuda()
200
 
201
- keywords = ['###']
202
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
203
 
204
  with torch.inference_mode():
@@ -208,22 +254,29 @@ def eval_model(args):
208
  do_sample=True,
209
  temperature=0.7,
210
  max_new_tokens=1024,
211
- stopping_criteria=[stopping_criteria])
 
212
 
213
  # TODO: new implementation
214
  input_token_len = input_ids.shape[1]
215
- n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
 
 
216
  if n_diff_input_output > 0:
217
- print(f'[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids')
218
- outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
219
-
220
- if args.conv_mode == 'simple_legacy':
 
 
 
 
221
  while True:
222
  cur_len = len(outputs)
223
  outputs = outputs.strip()
224
- for pattern in ['###', 'Assistant:', 'Response:']:
225
  if outputs.startswith(pattern):
226
- outputs = outputs[len(pattern):].strip()
227
  if len(outputs) == cur_len:
228
  break
229
 
@@ -238,11 +291,11 @@ def eval_model(args):
238
  # prompt for answer
239
  if args.answer_prompter:
240
  outputs_reasoning = outputs
241
- inputs = tokenizer([prompt + outputs_reasoning + ' ###\nANSWER:'])
242
 
243
  input_ids = torch.as_tensor(inputs.input_ids).cuda()
244
 
245
- keywords = ['###']
246
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
247
 
248
  with torch.inference_mode():
@@ -252,13 +305,20 @@ def eval_model(args):
252
  do_sample=True,
253
  temperature=0.7,
254
  max_new_tokens=64,
255
- stopping_criteria=[stopping_criteria])
 
256
 
257
  input_token_len = input_ids.shape[1]
258
- n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
 
 
259
  if n_diff_input_output > 0:
260
- print(f'[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids')
261
- outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
 
 
 
 
262
 
263
  try:
264
  index = outputs.index(conv.sep)
@@ -267,7 +327,7 @@ def eval_model(args):
267
  index = outputs.index(conv.sep)
268
 
269
  outputs = outputs[:index].strip()
270
- outputs = outputs_reasoning + '\n The answer is ' + outputs
271
 
272
  # new implementation ends
273
 
@@ -281,17 +341,24 @@ def eval_model(args):
281
 
282
  # outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
283
 
284
-
285
  ans_id = shortuuid.uuid()
286
- ans_file.write(json.dumps({"question_id": idx,
287
- "prompt": cur_prompt,
288
- "text": outputs,
289
- "answer_id": ans_id,
290
- "model_id": model_name,
291
- "metadata": {}}) + "\n")
 
 
 
 
 
 
 
292
  ans_file.flush()
293
  ans_file.close()
294
 
 
295
  if __name__ == "__main__":
296
  parser = argparse.ArgumentParser()
297
  parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
 
1
  import argparse
 
 
 
2
  import json
3
+ import math
4
+ import os
5
+ import random
6
 
7
+ import shortuuid
8
+ import torch
9
  from llava import LlavaLlamaForCausalLM
10
  from llava.conversation import conv_templates
11
  from llava.utils import disable_torch_init
 
 
12
  from PIL import Image
13
+ from tqdm import tqdm
14
+ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
15
+ CLIPImageProcessor, CLIPVisionModel,
16
+ StoppingCriteria)
17
 
18
 
19
  def split_list(lst, n):
20
  """Split a list into n (roughly) equal-sized chunks"""
21
  chunk_size = math.ceil(len(lst) / n) # integer division
22
+ return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
23
 
24
 
25
  def get_chunk(lst, n, k):
 
33
  DEFAULT_IM_END_TOKEN = "<im_end>"
34
 
35
 
 
 
36
  detail_describe_instructions = [
37
  "Describe the following image in detail.",
38
  "Provide a detailed description of the given image.",
 
68
 
69
  prompt_pool = detail_describe_instructions + concise_describe_instructions
70
 
71
+ prompt_pool = ["Describe the following image in detail."]
72
 
73
 
74
  def patch_config(config):
75
  patch_dict = {
76
  "use_mm_proj": True,
77
  "mm_vision_tower": "openai/clip-vit-large-patch14",
78
+ "mm_hidden_size": 1024,
79
  }
80
 
81
  cfg = AutoConfig.from_pretrained(config)
82
  if not hasattr(cfg, "mm_vision_tower"):
83
+ print(
84
+ f"`mm_vision_tower` not found in `{config}`, applying patch and save to disk."
85
+ )
86
  for k, v in patch_dict.items():
87
  setattr(cfg, k, v)
88
  cfg.save_pretrained(config)
 
96
  self.start_len = None
97
  self.input_ids = input_ids
98
 
99
+ def __call__(
100
+ self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
101
+ ) -> bool:
102
  if self.start_len is None:
103
  self.start_len = self.input_ids.shape[1]
104
  else:
105
+ outputs = self.tokenizer.batch_decode(
106
+ output_ids[:, self.start_len :], skip_special_tokens=True
107
+ )[0]
108
  for keyword in self.keywords:
109
  if keyword in outputs:
110
  return True
 
118
  tokenizer = AutoTokenizer.from_pretrained(model_name)
119
  if args.mm_projector is None:
120
  patch_config(model_name)
121
+ model = LlavaLlamaForCausalLM.from_pretrained(
122
+ model_name, torch_dtype=torch.float16, use_cache=True
123
+ ).cuda()
124
+ image_processor = CLIPImageProcessor.from_pretrained(
125
+ model.config.mm_vision_tower, torch_dtype=torch.float16
126
+ )
127
 
128
  mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
129
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
130
  if mm_use_im_start_end:
131
+ tokenizer.add_tokens(
132
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
133
+ )
134
 
135
  vision_tower = model.model.vision_tower[0]
136
+ vision_tower.to(device="cuda", dtype=torch.float16)
137
  vision_config = vision_tower.config
138
+ vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
139
+ [DEFAULT_IMAGE_PATCH_TOKEN]
140
+ )[0]
141
  vision_config.use_im_start_end = mm_use_im_start_end
142
  if mm_use_im_start_end:
143
+ (
144
+ vision_config.im_start_token,
145
+ vision_config.im_end_token,
146
+ ) = tokenizer.convert_tokens_to_ids(
147
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
148
+ )
149
  image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
150
  else:
151
  # in case of using a pretrained model with only a MLP projector weights
152
+ model = LlavaLlamaForCausalLM.from_pretrained(
153
+ model_name, torch_dtype=torch.float16, use_cache=True
154
+ ).cuda()
155
 
156
  mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
157
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
158
  if mm_use_im_start_end:
159
+ tokenizer.add_tokens(
160
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
161
+ )
162
 
163
+ vision_tower = CLIPVisionModel.from_pretrained(
164
+ args.vision_tower, torch_dtype=torch.float16
165
+ ).cuda()
166
+ image_processor = CLIPImageProcessor.from_pretrained(
167
+ args.vision_tower, torch_dtype=torch.float16
168
+ )
169
 
170
  vision_config = vision_tower.config
171
+ vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
172
+ [DEFAULT_IMAGE_PATCH_TOKEN]
173
+ )[0]
174
  vision_config.use_im_start_end = mm_use_im_start_end
175
  if mm_use_im_start_end:
176
+ (
177
+ vision_config.im_start_token,
178
+ vision_config.im_end_token,
179
+ ) = tokenizer.convert_tokens_to_ids(
180
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
181
+ )
182
 
183
  image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
184
 
185
+ mm_projector = torch.nn.Linear(
186
+ vision_config.hidden_size, model.config.hidden_size
187
+ )
188
+ mm_projector_weights = torch.load(args.mm_projector, map_location="cpu")
189
+ mm_projector.load_state_dict(
190
+ {k.split(".")[-1]: v for k, v in mm_projector_weights.items()}
191
+ )
192
 
193
  model.model.mm_projector = mm_projector.cuda().half()
194
  model.model.vision_tower = [vision_tower]
 
199
  os.makedirs(os.path.dirname(answers_file), exist_ok=True)
200
  os.makedirs(os.path.join(os.path.dirname(answers_file), "images"), exist_ok=True)
201
  ans_file = open(answers_file, "w")
202
+ save_image_folder = os.path.join(
203
+ os.path.dirname(os.path.expanduser(args.answers_file)), "images"
204
+ )
205
  for i, line in enumerate(tqdm(questions)):
206
  idx = line["id"]
207
+ question = line["conversations"][0]
208
  gt_ans = line["conversations"][1]
 
 
209
 
210
+ qs = question["value"]
211
+
212
+ qs = qs.replace("<image>", "").strip()
213
  cur_prompt = qs
214
 
215
+ if "image" in line:
216
  image_file = line["image"]
217
  image = Image.open(os.path.join(args.image_folder, image_file))
218
+ image_tensor = image_processor.preprocess(image, return_tensors="pt")[
219
+ "pixel_values"
220
+ ][0]
221
  images = image_tensor.unsqueeze(0).half().cuda()
222
+ if getattr(model.config, "mm_use_im_start_end", False):
223
+ qs = (
224
+ qs
225
+ + "\n"
226
+ + DEFAULT_IM_START_TOKEN
227
+ + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
228
+ + DEFAULT_IM_END_TOKEN
229
+ )
230
  else:
231
+ qs = qs + "\n" + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
232
+ cur_prompt = cur_prompt + "\n" + "<image>"
233
  else:
234
  images = None
235
 
236
+ if args.conv_mode == "simple_legacy":
237
+ qs += "\n\n### Response:"
238
+ assert gt_ans["from"] == "gpt"
239
  # conv = default_conversation.copy()
240
  conv = conv_templates[args.conv_mode].copy()
241
  conv.append_message(conv.roles[0], qs)
 
244
 
245
  input_ids = torch.as_tensor(inputs.input_ids).cuda()
246
 
247
+ keywords = ["###"]
248
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
249
 
250
  with torch.inference_mode():
 
254
  do_sample=True,
255
  temperature=0.7,
256
  max_new_tokens=1024,
257
+ stopping_criteria=[stopping_criteria],
258
+ )
259
 
260
  # TODO: new implementation
261
  input_token_len = input_ids.shape[1]
262
+ n_diff_input_output = (
263
+ (input_ids != output_ids[:, :input_token_len]).sum().item()
264
+ )
265
  if n_diff_input_output > 0:
266
+ print(
267
+ f"[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids"
268
+ )
269
+ outputs = tokenizer.batch_decode(
270
+ output_ids[:, input_token_len:], skip_special_tokens=True
271
+ )[0]
272
+
273
+ if args.conv_mode == "simple_legacy":
274
  while True:
275
  cur_len = len(outputs)
276
  outputs = outputs.strip()
277
+ for pattern in ["###", "Assistant:", "Response:"]:
278
  if outputs.startswith(pattern):
279
+ outputs = outputs[len(pattern) :].strip()
280
  if len(outputs) == cur_len:
281
  break
282
 
 
291
  # prompt for answer
292
  if args.answer_prompter:
293
  outputs_reasoning = outputs
294
+ inputs = tokenizer([prompt + outputs_reasoning + " ###\nANSWER:"])
295
 
296
  input_ids = torch.as_tensor(inputs.input_ids).cuda()
297
 
298
+ keywords = ["###"]
299
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
300
 
301
  with torch.inference_mode():
 
305
  do_sample=True,
306
  temperature=0.7,
307
  max_new_tokens=64,
308
+ stopping_criteria=[stopping_criteria],
309
+ )
310
 
311
  input_token_len = input_ids.shape[1]
312
+ n_diff_input_output = (
313
+ (input_ids != output_ids[:, :input_token_len]).sum().item()
314
+ )
315
  if n_diff_input_output > 0:
316
+ print(
317
+ f"[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids"
318
+ )
319
+ outputs = tokenizer.batch_decode(
320
+ output_ids[:, input_token_len:], skip_special_tokens=True
321
+ )[0]
322
 
323
  try:
324
  index = outputs.index(conv.sep)
 
327
  index = outputs.index(conv.sep)
328
 
329
  outputs = outputs[:index].strip()
330
+ outputs = outputs_reasoning + "\n The answer is " + outputs
331
 
332
  # new implementation ends
333
 
 
341
 
342
  # outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
343
 
 
344
  ans_id = shortuuid.uuid()
345
+ ans_file.write(
346
+ json.dumps(
347
+ {
348
+ "question_id": idx,
349
+ "prompt": cur_prompt,
350
+ "text": outputs,
351
+ "answer_id": ans_id,
352
+ "model_id": model_name,
353
+ "metadata": {},
354
+ }
355
+ )
356
+ + "\n"
357
+ )
358
  ans_file.flush()
359
  ans_file.close()
360
 
361
+
362
  if __name__ == "__main__":
363
  parser = argparse.ArgumentParser()
364
  parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
model/llava/eval/qa_baseline_gpt35.py CHANGED
@@ -1,51 +1,57 @@
1
  """Generate answers with GPT-3.5"""
2
  # Note: you need to be using OpenAI Python v0.27.0 for the code below to work
3
  import argparse
 
4
  import json
5
  import os
6
  import time
7
- import concurrent.futures
8
 
9
  import openai
10
- import tqdm
11
  import shortuuid
 
 
 
 
12
 
13
- MODEL = 'gpt-3.5-turbo'
14
- MODEL_ID = 'gpt-3.5-turbo:20230327'
15
 
16
  def get_answer(question_id: int, question: str, max_tokens: int):
17
  ans = {
18
- 'answer_id': shortuuid.uuid(),
19
- 'question_id': question_id,
20
- 'model_id': MODEL_ID,
21
  }
22
  for _ in range(3):
23
  try:
24
  response = openai.ChatCompletion.create(
25
  model=MODEL,
26
- messages=[{
27
- 'role': 'system',
28
- 'content': 'You are a helpful assistant.'
29
- }, {
30
- 'role': 'user',
31
- 'content': question,
32
- }],
33
  max_tokens=max_tokens,
34
  )
35
- ans['text'] = response['choices'][0]['message']['content']
36
  return ans
37
  except Exception as e:
38
- print('[ERROR]', e)
39
- ans['text'] = '#ERROR#'
40
  time.sleep(1)
41
  return ans
42
 
43
 
44
- if __name__ == '__main__':
45
- parser = argparse.ArgumentParser(description='ChatGPT answer generation.')
46
- parser.add_argument('-q', '--question')
47
- parser.add_argument('-o', '--output')
48
- parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
 
 
 
 
 
49
  args = parser.parse_args()
50
 
51
  questions_dict = {}
@@ -54,7 +60,7 @@ if __name__ == '__main__':
54
  if not line:
55
  continue
56
  q = json.loads(line)
57
- questions_dict[q['question_id']] = q['text']
58
 
59
  answers = []
60
 
@@ -64,11 +70,13 @@ if __name__ == '__main__':
64
  future = executor.submit(get_answer, qid, question, args.max_tokens)
65
  futures.append(future)
66
 
67
- for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
 
 
68
  answers.append(future.result())
69
 
70
- answers.sort(key=lambda x: x['question_id'])
71
 
72
- with open(os.path.expanduser(args.output), 'w') as f:
73
  table = [json.dumps(ans) for ans in answers]
74
- f.write('\n'.join(table))
 
1
  """Generate answers with GPT-3.5"""
2
  # Note: you need to be using OpenAI Python v0.27.0 for the code below to work
3
  import argparse
4
+ import concurrent.futures
5
  import json
6
  import os
7
  import time
 
8
 
9
  import openai
 
10
  import shortuuid
11
+ import tqdm
12
+
13
+ MODEL = "gpt-3.5-turbo"
14
+ MODEL_ID = "gpt-3.5-turbo:20230327"
15
 
 
 
16
 
17
  def get_answer(question_id: int, question: str, max_tokens: int):
18
  ans = {
19
+ "answer_id": shortuuid.uuid(),
20
+ "question_id": question_id,
21
+ "model_id": MODEL_ID,
22
  }
23
  for _ in range(3):
24
  try:
25
  response = openai.ChatCompletion.create(
26
  model=MODEL,
27
+ messages=[
28
+ {"role": "system", "content": "You are a helpful assistant."},
29
+ {
30
+ "role": "user",
31
+ "content": question,
32
+ },
33
+ ],
34
  max_tokens=max_tokens,
35
  )
36
+ ans["text"] = response["choices"][0]["message"]["content"]
37
  return ans
38
  except Exception as e:
39
+ print("[ERROR]", e)
40
+ ans["text"] = "#ERROR#"
41
  time.sleep(1)
42
  return ans
43
 
44
 
45
+ if __name__ == "__main__":
46
+ parser = argparse.ArgumentParser(description="ChatGPT answer generation.")
47
+ parser.add_argument("-q", "--question")
48
+ parser.add_argument("-o", "--output")
49
+ parser.add_argument(
50
+ "--max-tokens",
51
+ type=int,
52
+ default=1024,
53
+ help="maximum number of tokens produced in the output",
54
+ )
55
  args = parser.parse_args()
56
 
57
  questions_dict = {}
 
60
  if not line:
61
  continue
62
  q = json.loads(line)
63
+ questions_dict[q["question_id"]] = q["text"]
64
 
65
  answers = []
66
 
 
70
  future = executor.submit(get_answer, qid, question, args.max_tokens)
71
  futures.append(future)
72
 
73
+ for future in tqdm.tqdm(
74
+ concurrent.futures.as_completed(futures), total=len(futures)
75
+ ):
76
  answers.append(future.result())
77
 
78
+ answers.sort(key=lambda x: x["question_id"])
79
 
80
+ with open(os.path.expanduser(args.output), "w") as f:
81
  table = [json.dumps(ans) for ans in answers]
82
+ f.write("\n".join(table))
model/llava/eval/run_llava.py CHANGED
@@ -1,20 +1,17 @@
1
  import argparse
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
- import torch
4
  import os
5
- from llava.conversation import conv_templates, SeparatorStyle
6
- from llava.utils import disable_torch_init
7
- from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
8
- from llava.model import *
9
- from llava.model.utils import KeywordsStoppingCriteria
10
-
11
- from PIL import Image
12
 
13
- import os
14
  import requests
 
 
 
 
 
15
  from PIL import Image
16
- from io import BytesIO
17
-
 
18
 
19
  DEFAULT_IMAGE_TOKEN = "<image>"
20
  DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
@@ -23,11 +20,11 @@ DEFAULT_IM_END_TOKEN = "<im_end>"
23
 
24
 
25
  def load_image(image_file):
26
- if image_file.startswith('http') or image_file.startswith('https'):
27
  response = requests.get(image_file)
28
- image = Image.open(BytesIO(response.content)).convert('RGB')
29
  else:
30
- image = Image.open(image_file).convert('RGB')
31
  return image
32
 
33
 
@@ -38,35 +35,63 @@ def eval_model(args):
38
  tokenizer = AutoTokenizer.from_pretrained(model_name)
39
 
40
  if "mpt" in model_name.lower():
41
- model = LlavaMPTForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()
42
  else:
43
  # model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()
44
- model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')#.cuda()
45
- image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)
 
 
 
 
46
 
47
  mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
48
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
49
  if mm_use_im_start_end:
50
- tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
 
 
51
 
52
  vision_tower = model.get_model().vision_tower[0]
53
- if vision_tower.device.type == 'meta':
54
- vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda()
 
 
 
 
55
  model.get_model().vision_tower[0] = vision_tower
56
  else:
57
- vision_tower.to(device='cuda', dtype=torch.float16)
58
  vision_config = vision_tower.config
59
- vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
 
 
60
  vision_config.use_im_start_end = mm_use_im_start_end
61
  if mm_use_im_start_end:
62
- vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
63
  image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
64
 
65
  qs = args.query
66
  if mm_use_im_start_end:
67
- qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
68
  else:
69
- qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
70
 
71
  if "v1" in model_name.lower():
72
  conv_mode = "llava_v1"
@@ -76,7 +101,11 @@ def eval_model(args):
76
  conv_mode = "multimodal"
77
 
78
  if args.conv_mode is not None and conv_mode != args.conv_mode:
79
- print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
 
 
 
 
80
  else:
81
  args.conv_mode = conv_mode
82
 
@@ -87,7 +116,9 @@ def eval_model(args):
87
  inputs = tokenizer([prompt])
88
 
89
  image = load_image(args.image_file)
90
- image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
 
 
91
 
92
  input_ids = torch.as_tensor(inputs.input_ids).cuda()
93
 
@@ -96,20 +127,34 @@ def eval_model(args):
96
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
97
 
98
  with torch.inference_mode():
99
- output_ids = model.generate(input_ids, images=image_tensor.unsqueeze(0).half().cuda(), do_sample=True, temperature=0.2, max_new_tokens=1024, stopping_criteria=[stopping_criteria])
100
 
101
  input_token_len = input_ids.shape[1]
102
  n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
103
  if n_diff_input_output > 0:
104
- print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
105
- outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
 
 
 
 
106
  outputs = outputs.strip()
107
  if outputs.endswith(stop_str):
108
- outputs = outputs[:-len(stop_str)]
109
  outputs = outputs.strip()
110
  print(outputs)
111
 
112
- import pdb; pdb.set_trace()
 
 
 
113
 
114
  if __name__ == "__main__":
115
  parser = argparse.ArgumentParser()
 
1
  import argparse
 
 
2
  import os
3
+ from io import BytesIO
4
 
 
5
  import requests
6
+ import torch
7
+ from llava.conversation import SeparatorStyle, conv_templates
8
+ from llava.model import *
9
+ from llava.model.utils import KeywordsStoppingCriteria
10
+ from llava.utils import disable_torch_init
11
  from PIL import Image
12
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
13
+ CLIPImageProcessor, CLIPVisionModel,
14
+ StoppingCriteria)
15
 
16
  DEFAULT_IMAGE_TOKEN = "<image>"
17
  DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
 
20
 
21
 
22
  def load_image(image_file):
23
+ if image_file.startswith("http") or image_file.startswith("https"):
24
  response = requests.get(image_file)
25
+ image = Image.open(BytesIO(response.content)).convert("RGB")
26
  else:
27
+ image = Image.open(image_file).convert("RGB")
28
  return image
29
 
30
 
 
35
  tokenizer = AutoTokenizer.from_pretrained(model_name)
36
 
37
  if "mpt" in model_name.lower():
38
+ model = LlavaMPTForCausalLM.from_pretrained(
39
+ model_name,
40
+ low_cpu_mem_usage=True,
41
+ torch_dtype=torch.float16,
42
+ use_cache=True,
43
+ ).cuda()
44
  else:
45
  # model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()
46
+ model = LlavaLlamaForCausalLM.from_pretrained(
47
+ model_name, torch_dtype=torch.float16, device_map="auto"
48
+ ) # .cuda()
49
+ image_processor = CLIPImageProcessor.from_pretrained(
50
+ model.config.mm_vision_tower, torch_dtype=torch.float16
51
+ )
52
 
53
  mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
54
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
55
  if mm_use_im_start_end:
56
+ tokenizer.add_tokens(
57
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
58
+ )
59
 
60
  vision_tower = model.get_model().vision_tower[0]
61
+ if vision_tower.device.type == "meta":
62
+ vision_tower = CLIPVisionModel.from_pretrained(
63
+ vision_tower.config._name_or_path,
64
+ torch_dtype=torch.float16,
65
+ low_cpu_mem_usage=True,
66
+ ).cuda()
67
  model.get_model().vision_tower[0] = vision_tower
68
  else:
69
+ vision_tower.to(device="cuda", dtype=torch.float16)
70
  vision_config = vision_tower.config
71
+ vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
72
+ [DEFAULT_IMAGE_PATCH_TOKEN]
73
+ )[0]
74
  vision_config.use_im_start_end = mm_use_im_start_end
75
  if mm_use_im_start_end:
76
+ (
77
+ vision_config.im_start_token,
78
+ vision_config.im_end_token,
79
+ ) = tokenizer.convert_tokens_to_ids(
80
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
81
+ )
82
  image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
83
 
84
  qs = args.query
85
  if mm_use_im_start_end:
86
+ qs = (
87
+ qs
88
+ + "\n"
89
+ + DEFAULT_IM_START_TOKEN
90
+ + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
91
+ + DEFAULT_IM_END_TOKEN
92
+ )
93
  else:
94
+ qs = qs + "\n" + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
95
 
96
  if "v1" in model_name.lower():
97
  conv_mode = "llava_v1"
 
101
  conv_mode = "multimodal"
102
 
103
  if args.conv_mode is not None and conv_mode != args.conv_mode:
104
+ print(
105
+ "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
106
+ conv_mode, args.conv_mode, args.conv_mode
107
+ )
108
+ )
109
  else:
110
  args.conv_mode = conv_mode
111
 
 
116
  inputs = tokenizer([prompt])
117
 
118
  image = load_image(args.image_file)
119
+ image_tensor = image_processor.preprocess(image, return_tensors="pt")[
120
+ "pixel_values"
121
+ ][0]
122
 
123
  input_ids = torch.as_tensor(inputs.input_ids).cuda()
124
 
 
127
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
128
 
129
  with torch.inference_mode():
130
+ output_ids = model.generate(
131
+ input_ids,
132
+ images=image_tensor.unsqueeze(0).half().cuda(),
133
+ do_sample=True,
134
+ temperature=0.2,
135
+ max_new_tokens=1024,
136
+ stopping_criteria=[stopping_criteria],
137
+ )
138
 
139
  input_token_len = input_ids.shape[1]
140
  n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
141
  if n_diff_input_output > 0:
142
+ print(
143
+ f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
144
+ )
145
+ outputs = tokenizer.batch_decode(
146
+ output_ids[:, input_token_len:], skip_special_tokens=True
147
+ )[0]
148
  outputs = outputs.strip()
149
  if outputs.endswith(stop_str):
150
+ outputs = outputs[: -len(stop_str)]
151
  outputs = outputs.strip()
152
  print(outputs)
153
 
154
+ import pdb
155
+
156
+ pdb.set_trace()
157
+
158
 
159
  if __name__ == "__main__":
160
  parser = argparse.ArgumentParser()
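Note on run_llava.py: only the tokens generated after the prompt are decoded (`output_ids[:, input_token_len:]`), and a trailing stop string is trimmed before printing. Below is a minimal sketch of that post-processing with a generic text-only model; `gpt2` is only a stand-in, since the real script loads a LLaVA checkpoint and also passes `images=...` to generate().

```
# Hedged sketch of run_llava.py's output post-processing.
# "gpt2" is a stand-in; the real script uses LLaVA with image inputs.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Describe the scene.\nAssistant:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

with torch.inference_mode():
    output_ids = model.generate(input_ids, max_new_tokens=32, do_sample=False)

# Decode only the newly generated tokens, then trim the stop string.
input_token_len = input_ids.shape[1]
n_diff = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff > 0:
    print("[Warning] {} output_ids are not the same as the input_ids".format(n_diff))
outputs = tokenizer.batch_decode(
    output_ids[:, input_token_len:], skip_special_tokens=True
)[0].strip()

stop_str = "###"
if outputs.endswith(stop_str):
    outputs = outputs[: -len(stop_str)]
print(outputs.strip())
```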
model/llava/eval/run_llava_batch.py CHANGED
@@ -1,24 +1,21 @@
1
  import argparse
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
- import torch
4
- import os
5
- from llava.conversation import conv_templates, SeparatorStyle
6
- from llava.utils import disable_torch_init
7
- from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
8
- from llava.model import *
9
- from llava.model.utils import KeywordsStoppingCriteria
10
-
11
- from PIL import Image
12
-
13
  import os
14
- import requests
15
- from PIL import Image
16
  from io import BytesIO
17
 
18
- import glob
19
  import numpy as np
20
- import json
 
21
  import tqdm
 
 
 
 
 
 
 
 
22
 
23
  DEFAULT_IMAGE_TOKEN = "<image>"
24
  DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
@@ -27,42 +24,167 @@ DEFAULT_IM_END_TOKEN = "<im_end>"
27
 
28
 
29
  def load_image(image_file):
30
- if image_file.startswith('http') or image_file.startswith('https'):
31
  response = requests.get(image_file)
32
- image = Image.open(BytesIO(response.content)).convert('RGB')
33
  else:
34
- image = Image.open(image_file).convert('RGB')
35
  return image
36
 
37
 
38
- classes = ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road',
39
- 'bed', 'windowpane', 'grass', 'cabinet', 'sidewalk',
40
- 'person', 'earth', 'door', 'table', 'mountain', 'plant',
41
- 'curtain', 'chair', 'car', 'water', 'painting', 'sofa',
42
- 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair',
43
- 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp',
44
- 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
45
- 'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
46
- 'skyscraper', 'fireplace', 'refrigerator', 'grandstand',
47
- 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow',
48
- 'screen door', 'stairway', 'river', 'bridge', 'bookcase',
49
- 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill',
50
- 'bench', 'countertop', 'stove', 'palm', 'kitchen island',
51
- 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine',
52
- 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
53
- 'chandelier', 'awning', 'streetlight', 'booth',
54
- 'television receiver', 'airplane', 'dirt track', 'apparel',
55
- 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle',
56
- 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain',
57
- 'conveyer belt', 'canopy', 'washer', 'plaything',
58
- 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall',
59
- 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food',
60
- 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal',
61
- 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket',
62
- 'sculpture', 'hood', 'sconce', 'vase', 'traffic light',
63
- 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate',
64
- 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
65
- 'clock', 'flag']
66
 
67
  def eval_model(args):
68
  # Model
@@ -71,35 +193,58 @@ def eval_model(args):
71
  tokenizer = AutoTokenizer.from_pretrained(model_name)
72
 
73
  if "mpt" in model_name.lower():
74
- model = LlavaMPTForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()
 
 
 
 
 
75
  else:
76
  # model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()
77
- model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')#.cuda()
78
- image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)
 
 
 
 
79
 
80
  mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
81
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
82
  if mm_use_im_start_end:
83
- tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
 
 
84
 
85
  vision_tower = model.get_model().vision_tower[0]
86
- if vision_tower.device.type == 'meta':
87
- vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda()
 
 
 
 
88
  model.get_model().vision_tower[0] = vision_tower
89
  else:
90
- vision_tower.to(device='cuda', dtype=torch.float16)
91
  vision_config = vision_tower.config
92
- vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
 
 
93
  vision_config.use_im_start_end = mm_use_im_start_end
94
  if mm_use_im_start_end:
95
- vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
96
  image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
97
 
98
  # paths for all images
99
- images = sorted(glob.glob("/mnt/proj74/xinlai/dataset/ade20k/images/training/*.jpg"))
 
 
100
  results = []
101
  for i, image_file in enumerate(tqdm.tqdm(images)):
102
-
103
  # if i == 2:
104
  # break
105
 
@@ -109,7 +254,9 @@ def eval_model(args):
109
  print("i: {}, len(images): {}".format(i, len(images)))
110
 
111
  image = load_image(image_file)
112
- image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
 
 
113
  image_tensor = image_tensor.unsqueeze(0).half().cuda()
114
 
115
  label_file = image_file.replace("images", "annotations").replace(".jpg", ".png")
@@ -126,9 +273,15 @@ def eval_model(args):
126
 
127
  # qs = args.query
128
  if mm_use_im_start_end:
129
- qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
130
  else:
131
- qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
132
 
133
  if "v1" in model_name.lower():
134
  conv_mode = "llava_v1"
@@ -138,7 +291,11 @@ def eval_model(args):
138
  conv_mode = "multimodal"
139
 
140
  if args.conv_mode is not None and conv_mode != args.conv_mode:
141
- print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
 
 
 
 
142
  else:
143
  args.conv_mode = conv_mode
144
 
@@ -164,27 +321,41 @@ def eval_model(args):
164
  images=image_tensor,
165
  do_sample=True,
166
  temperature=0.2,
167
- max_new_tokens=512, #1024,
168
- stopping_criteria=[stopping_criteria])
 
169
 
170
  input_token_len = input_ids.shape[1]
171
- n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
 
 
172
  if n_diff_input_output > 0:
173
- print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
174
- outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
 
 
 
 
175
  outputs = outputs.strip()
176
  if outputs.endswith(stop_str):
177
- outputs = outputs[:-len(stop_str)]
178
  outputs = outputs.strip()
179
 
180
  print("qs: {}, output: {}, image_file: {}".format(qs, outputs, image_file))
181
 
182
- results.append({'image_id': image_file.split("/")[-1], 'input': input_conv, 'output': outputs})
183
 
184
  with open("/mnt/proj74/xinlai/LLM/LLaVA/ade20k_conversations.json", "w+") as f:
185
  json.dump(results, f)
186
 
187
- # print(outputs)
 
188
 
189
  if __name__ == "__main__":
190
  parser = argparse.ArgumentParser()
 
1
  import argparse
2
+ import glob
3
+ import json
 
 
 
 
 
 
 
 
 
4
  import os
 
 
5
  from io import BytesIO
6
 
 
7
  import numpy as np
8
+ import requests
9
+ import torch
10
  import tqdm
11
+ from llava.conversation import SeparatorStyle, conv_templates
12
+ from llava.model import *
13
+ from llava.model.utils import KeywordsStoppingCriteria
14
+ from llava.utils import disable_torch_init
15
+ from PIL import Image
16
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
17
+ CLIPImageProcessor, CLIPVisionModel,
18
+ StoppingCriteria)
19
 
20
  DEFAULT_IMAGE_TOKEN = "<image>"
21
  DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
 
24
 
25
 
26
  def load_image(image_file):
27
+ if image_file.startswith("http") or image_file.startswith("https"):
28
  response = requests.get(image_file)
29
+ image = Image.open(BytesIO(response.content)).convert("RGB")
30
  else:
31
+ image = Image.open(image_file).convert("RGB")
32
  return image
33
 
34
 
35
+ classes = [
36
+ "wall",
37
+ "building",
38
+ "sky",
39
+ "floor",
40
+ "tree",
41
+ "ceiling",
42
+ "road",
43
+ "bed",
44
+ "windowpane",
45
+ "grass",
46
+ "cabinet",
47
+ "sidewalk",
48
+ "person",
49
+ "earth",
50
+ "door",
51
+ "table",
52
+ "mountain",
53
+ "plant",
54
+ "curtain",
55
+ "chair",
56
+ "car",
57
+ "water",
58
+ "painting",
59
+ "sofa",
60
+ "shelf",
61
+ "house",
62
+ "sea",
63
+ "mirror",
64
+ "rug",
65
+ "field",
66
+ "armchair",
67
+ "seat",
68
+ "fence",
69
+ "desk",
70
+ "rock",
71
+ "wardrobe",
72
+ "lamp",
73
+ "bathtub",
74
+ "railing",
75
+ "cushion",
76
+ "base",
77
+ "box",
78
+ "column",
79
+ "signboard",
80
+ "chest of drawers",
81
+ "counter",
82
+ "sand",
83
+ "sink",
84
+ "skyscraper",
85
+ "fireplace",
86
+ "refrigerator",
87
+ "grandstand",
88
+ "path",
89
+ "stairs",
90
+ "runway",
91
+ "case",
92
+ "pool table",
93
+ "pillow",
94
+ "screen door",
95
+ "stairway",
96
+ "river",
97
+ "bridge",
98
+ "bookcase",
99
+ "blind",
100
+ "coffee table",
101
+ "toilet",
102
+ "flower",
103
+ "book",
104
+ "hill",
105
+ "bench",
106
+ "countertop",
107
+ "stove",
108
+ "palm",
109
+ "kitchen island",
110
+ "computer",
111
+ "swivel chair",
112
+ "boat",
113
+ "bar",
114
+ "arcade machine",
115
+ "hovel",
116
+ "bus",
117
+ "towel",
118
+ "light",
119
+ "truck",
120
+ "tower",
121
+ "chandelier",
122
+ "awning",
123
+ "streetlight",
124
+ "booth",
125
+ "television receiver",
126
+ "airplane",
127
+ "dirt track",
128
+ "apparel",
129
+ "pole",
130
+ "land",
131
+ "bannister",
132
+ "escalator",
133
+ "ottoman",
134
+ "bottle",
135
+ "buffet",
136
+ "poster",
137
+ "stage",
138
+ "van",
139
+ "ship",
140
+ "fountain",
141
+ "conveyer belt",
142
+ "canopy",
143
+ "washer",
144
+ "plaything",
145
+ "swimming pool",
146
+ "stool",
147
+ "barrel",
148
+ "basket",
149
+ "waterfall",
150
+ "tent",
151
+ "bag",
152
+ "minibike",
153
+ "cradle",
154
+ "oven",
155
+ "ball",
156
+ "food",
157
+ "step",
158
+ "tank",
159
+ "trade name",
160
+ "microwave",
161
+ "pot",
162
+ "animal",
163
+ "bicycle",
164
+ "lake",
165
+ "dishwasher",
166
+ "screen",
167
+ "blanket",
168
+ "sculpture",
169
+ "hood",
170
+ "sconce",
171
+ "vase",
172
+ "traffic light",
173
+ "tray",
174
+ "ashcan",
175
+ "fan",
176
+ "pier",
177
+ "crt screen",
178
+ "plate",
179
+ "monitor",
180
+ "bulletin board",
181
+ "shower",
182
+ "radiator",
183
+ "glass",
184
+ "clock",
185
+ "flag",
186
+ ]
187
+
188
 
189
  def eval_model(args):
190
  # Model
 
193
  tokenizer = AutoTokenizer.from_pretrained(model_name)
194
 
195
  if "mpt" in model_name.lower():
196
+ model = LlavaMPTForCausalLM.from_pretrained(
197
+ model_name,
198
+ low_cpu_mem_usage=True,
199
+ torch_dtype=torch.float16,
200
+ use_cache=True,
201
+ ).cuda()
202
  else:
203
  # model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()
204
+ model = LlavaLlamaForCausalLM.from_pretrained(
205
+ model_name, torch_dtype=torch.float16, device_map="auto"
206
+ ) # .cuda()
207
+ image_processor = CLIPImageProcessor.from_pretrained(
208
+ model.config.mm_vision_tower, torch_dtype=torch.float16
209
+ )
210
 
211
  mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
212
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
213
  if mm_use_im_start_end:
214
+ tokenizer.add_tokens(
215
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
216
+ )
217
 
218
  vision_tower = model.get_model().vision_tower[0]
219
+ if vision_tower.device.type == "meta":
220
+ vision_tower = CLIPVisionModel.from_pretrained(
221
+ vision_tower.config._name_or_path,
222
+ torch_dtype=torch.float16,
223
+ low_cpu_mem_usage=True,
224
+ ).cuda()
225
  model.get_model().vision_tower[0] = vision_tower
226
  else:
227
+ vision_tower.to(device="cuda", dtype=torch.float16)
228
  vision_config = vision_tower.config
229
+ vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
230
+ [DEFAULT_IMAGE_PATCH_TOKEN]
231
+ )[0]
232
  vision_config.use_im_start_end = mm_use_im_start_end
233
  if mm_use_im_start_end:
234
+ (
235
+ vision_config.im_start_token,
236
+ vision_config.im_end_token,
237
+ ) = tokenizer.convert_tokens_to_ids(
238
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
239
+ )
240
  image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
241
 
242
  # paths for all images
243
+ images = sorted(
244
+ glob.glob("/mnt/proj74/xinlai/dataset/ade20k/images/training/*.jpg")
245
+ )
246
  results = []
247
  for i, image_file in enumerate(tqdm.tqdm(images)):
 
248
  # if i == 2:
249
  # break
250
 
 
254
  print("i: {}, len(images): {}".format(i, len(images)))
255
 
256
  image = load_image(image_file)
257
+ image_tensor = image_processor.preprocess(image, return_tensors="pt")[
258
+ "pixel_values"
259
+ ][0]
260
  image_tensor = image_tensor.unsqueeze(0).half().cuda()
261
 
262
  label_file = image_file.replace("images", "annotations").replace(".jpg", ".png")
 
273
 
274
  # qs = args.query
275
  if mm_use_im_start_end:
276
+ qs = (
277
+ qs
278
+ + "\n"
279
+ + DEFAULT_IM_START_TOKEN
280
+ + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
281
+ + DEFAULT_IM_END_TOKEN
282
+ )
283
  else:
284
+ qs = qs + "\n" + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
285
 
286
  if "v1" in model_name.lower():
287
  conv_mode = "llava_v1"
 
291
  conv_mode = "multimodal"
292
 
293
  if args.conv_mode is not None and conv_mode != args.conv_mode:
294
+ print(
295
+ "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
296
+ conv_mode, args.conv_mode, args.conv_mode
297
+ )
298
+ )
299
  else:
300
  args.conv_mode = conv_mode
301
 
 
321
  images=image_tensor,
322
  do_sample=True,
323
  temperature=0.2,
324
+ max_new_tokens=512, # 1024,
325
+ stopping_criteria=[stopping_criteria],
326
+ )
327
 
328
  input_token_len = input_ids.shape[1]
329
+ n_diff_input_output = (
330
+ (input_ids != output_ids[:, :input_token_len]).sum().item()
331
+ )
332
  if n_diff_input_output > 0:
333
+ print(
334
+ f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
335
+ )
336
+ outputs = tokenizer.batch_decode(
337
+ output_ids[:, input_token_len:], skip_special_tokens=True
338
+ )[0]
339
  outputs = outputs.strip()
340
  if outputs.endswith(stop_str):
341
+ outputs = outputs[: -len(stop_str)]
342
  outputs = outputs.strip()
343
 
344
  print("qs: {}, output: {}, image_file: {}".format(qs, outputs, image_file))
345
 
346
+ results.append(
347
+ {
348
+ "image_id": image_file.split("/")[-1],
349
+ "input": input_conv,
350
+ "output": outputs,
351
+ }
352
+ )
353
 
354
  with open("/mnt/proj74/xinlai/LLM/LLaVA/ade20k_conversations.json", "w+") as f:
355
  json.dump(results, f)
356
 
357
+ # print(outputs)
358
+
359
 
360
  if __name__ == "__main__":
361
  parser = argparse.ArgumentParser()
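Aside (not part of the diff): the three `run_llava_batch*.py` scripts in this commit build their prompts the same way. Purely as a reading aid, a minimal sketch of that prompt construction, with the start token string assumed to match the scripts' `DEFAULT_IM_START_TOKEN`:
```
# Sketch only: mirrors the mm_use_im_start_end branch reformatted above.
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"  # assumed value, as defined in the scripts
DEFAULT_IM_END_TOKEN = "<im_end>"


def build_query(qs: str, image_token_len: int, mm_use_im_start_end: bool) -> str:
    """Append the image placeholder tokens to a text query."""
    if mm_use_im_start_end:
        return (
            qs
            + "\n"
            + DEFAULT_IM_START_TOKEN
            + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
            + DEFAULT_IM_END_TOKEN
        )
    return qs + "\n" + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len


# Example: build_query("Describe the sofa.", image_token_len=256, mm_use_im_start_end=True)
# image_token_len = (image_size // patch_size) ** 2, i.e. 256 for CLIP ViT-L/14 at 224px.
```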
model/llava/eval/run_llava_batch_v2.py CHANGED
@@ -1,24 +1,21 @@
1
  import argparse
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
- import torch
4
- import os
5
- from llava.conversation import conv_templates, SeparatorStyle
6
- from llava.utils import disable_torch_init
7
- from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
8
- from llava.model import *
9
- from llava.model.utils import KeywordsStoppingCriteria
10
-
11
- from PIL import Image
12
-
13
  import os
14
- import requests
15
- from PIL import Image
16
  from io import BytesIO
17
 
18
- import glob
19
  import numpy as np
20
- import json
 
21
  import tqdm
 
 
 
 
 
 
 
 
22
 
23
  DEFAULT_IMAGE_TOKEN = "<image>"
24
  DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
@@ -27,42 +24,167 @@ DEFAULT_IM_END_TOKEN = "<im_end>"
27
 
28
 
29
  def load_image(image_file):
30
- if image_file.startswith('http') or image_file.startswith('https'):
31
  response = requests.get(image_file)
32
- image = Image.open(BytesIO(response.content)).convert('RGB')
33
  else:
34
- image = Image.open(image_file).convert('RGB')
35
  return image
36
 
37
 
38
- classes = ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road',
39
- 'bed', 'windowpane', 'grass', 'cabinet', 'sidewalk',
40
- 'person', 'earth', 'door', 'table', 'mountain', 'plant',
41
- 'curtain', 'chair', 'car', 'water', 'painting', 'sofa',
42
- 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair',
43
- 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp',
44
- 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
45
- 'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
46
- 'skyscraper', 'fireplace', 'refrigerator', 'grandstand',
47
- 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow',
48
- 'screen door', 'stairway', 'river', 'bridge', 'bookcase',
49
- 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill',
50
- 'bench', 'countertop', 'stove', 'palm', 'kitchen island',
51
- 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine',
52
- 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
53
- 'chandelier', 'awning', 'streetlight', 'booth',
54
- 'television receiver', 'airplane', 'dirt track', 'apparel',
55
- 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle',
56
- 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain',
57
- 'conveyer belt', 'canopy', 'washer', 'plaything',
58
- 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall',
59
- 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food',
60
- 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal',
61
- 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket',
62
- 'sculpture', 'hood', 'sconce', 'vase', 'traffic light',
63
- 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate',
64
- 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
65
- 'clock', 'flag']
66
 
67
  def eval_model(args):
68
  # Model
@@ -71,35 +193,58 @@ def eval_model(args):
71
  tokenizer = AutoTokenizer.from_pretrained(model_name)
72
 
73
  if "mpt" in model_name.lower():
74
- model = LlavaMPTForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()
 
 
 
 
 
75
  else:
76
  # model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()
77
- model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')#.cuda()
78
- image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)
 
 
 
 
79
 
80
  mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
81
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
82
  if mm_use_im_start_end:
83
- tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
 
 
84
 
85
  vision_tower = model.get_model().vision_tower[0]
86
- if vision_tower.device.type == 'meta':
87
- vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda()
 
 
 
 
88
  model.get_model().vision_tower[0] = vision_tower
89
  # else:
90
- # vision_tower.to(device='cuda', dtype=torch.float16)
91
  vision_config = vision_tower.config
92
- vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
 
 
93
  vision_config.use_im_start_end = mm_use_im_start_end
94
  if mm_use_im_start_end:
95
- vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
 
 
 
 
 
96
  image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
97
 
98
  # paths for all images
99
- images = sorted(glob.glob("/mnt/proj74/xinlai/dataset/ade20k/images/training/*.jpg"))
 
 
100
  results = []
101
  for i, image_file in enumerate(tqdm.tqdm(images)):
102
-
103
  # if i == 2:
104
  # break
105
 
@@ -115,7 +260,9 @@ def eval_model(args):
115
  label_unique = np.unique(label)
116
 
117
  image = load_image(image_file)
118
- image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
 
 
119
  image_tensor = image_tensor.unsqueeze(0).half().cuda()
120
 
121
  for label in label_unique:
@@ -128,9 +275,15 @@ def eval_model(args):
128
 
129
  # qs = args.query
130
  if mm_use_im_start_end:
131
- qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
 
 
 
 
 
 
132
  else:
133
- qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
134
 
135
  if "v1" in model_name.lower():
136
  conv_mode = "llava_v1"
@@ -140,7 +293,11 @@ def eval_model(args):
140
  conv_mode = "multimodal"
141
 
142
  if args.conv_mode is not None and conv_mode != args.conv_mode:
143
- print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
 
 
 
 
144
  else:
145
  args.conv_mode = conv_mode
146
 
@@ -173,32 +330,46 @@ def eval_model(args):
173
  images=image_tensor,
174
  # do_sample=True,
175
  # temperature=0.2,
176
- max_new_tokens=512, #1024,
177
- stopping_criteria=[stopping_criteria])
 
178
 
179
  input_token_len = input_ids.shape[1]
180
- n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
 
 
181
  if n_diff_input_output > 0:
182
- print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
 
 
183
 
184
  outputs_list = []
185
  for output_id in output_ids:
186
- outputs = tokenizer.batch_decode(output_id[:, input_token_len:], skip_special_tokens=True)[0]
 
 
187
  outputs = outputs.strip()
188
  if outputs.endswith(stop_str):
189
- outputs = outputs[:-len(stop_str)]
190
  outputs = outputs.strip()
191
  outputs_list.append(outputs)
192
 
193
  for qs, outputs in zip(prompt_list, outputs_list):
194
  print("qs: {}, output: {}, image_file: {}".format(qs, outputs, image_file))
195
 
196
- results.append({'image_id': image_file.split("/")[-1], 'input': prompt_list, 'output': outputs_list})
 
 
 
 
 
 
197
 
198
  with open("/mnt/proj74/xinlai/LLM/LLaVA/ade20k_conversations.json", "w+") as f:
199
  json.dump(results, f)
200
 
201
- # print(outputs)
 
202
 
203
  if __name__ == "__main__":
204
  parser = argparse.ArgumentParser()
 
1
  import argparse
2
+ import glob
3
+ import json
 
 
 
 
 
 
 
 
 
4
  import os
 
 
5
  from io import BytesIO
6
 
 
7
  import numpy as np
8
+ import requests
9
+ import torch
10
  import tqdm
11
+ from llava.conversation import SeparatorStyle, conv_templates
12
+ from llava.model import *
13
+ from llava.model.utils import KeywordsStoppingCriteria
14
+ from llava.utils import disable_torch_init
15
+ from PIL import Image
16
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
17
+ CLIPImageProcessor, CLIPVisionModel,
18
+ StoppingCriteria)
19
 
20
  DEFAULT_IMAGE_TOKEN = "<image>"
21
  DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
 
24
 
25
 
26
  def load_image(image_file):
27
+ if image_file.startswith("http") or image_file.startswith("https"):
28
  response = requests.get(image_file)
29
+ image = Image.open(BytesIO(response.content)).convert("RGB")
30
  else:
31
+ image = Image.open(image_file).convert("RGB")
32
  return image
33
 
34
 
35
+ classes = [
36
+ "wall",
37
+ "building",
38
+ "sky",
39
+ "floor",
40
+ "tree",
41
+ "ceiling",
42
+ "road",
43
+ "bed",
44
+ "windowpane",
45
+ "grass",
46
+ "cabinet",
47
+ "sidewalk",
48
+ "person",
49
+ "earth",
50
+ "door",
51
+ "table",
52
+ "mountain",
53
+ "plant",
54
+ "curtain",
55
+ "chair",
56
+ "car",
57
+ "water",
58
+ "painting",
59
+ "sofa",
60
+ "shelf",
61
+ "house",
62
+ "sea",
63
+ "mirror",
64
+ "rug",
65
+ "field",
66
+ "armchair",
67
+ "seat",
68
+ "fence",
69
+ "desk",
70
+ "rock",
71
+ "wardrobe",
72
+ "lamp",
73
+ "bathtub",
74
+ "railing",
75
+ "cushion",
76
+ "base",
77
+ "box",
78
+ "column",
79
+ "signboard",
80
+ "chest of drawers",
81
+ "counter",
82
+ "sand",
83
+ "sink",
84
+ "skyscraper",
85
+ "fireplace",
86
+ "refrigerator",
87
+ "grandstand",
88
+ "path",
89
+ "stairs",
90
+ "runway",
91
+ "case",
92
+ "pool table",
93
+ "pillow",
94
+ "screen door",
95
+ "stairway",
96
+ "river",
97
+ "bridge",
98
+ "bookcase",
99
+ "blind",
100
+ "coffee table",
101
+ "toilet",
102
+ "flower",
103
+ "book",
104
+ "hill",
105
+ "bench",
106
+ "countertop",
107
+ "stove",
108
+ "palm",
109
+ "kitchen island",
110
+ "computer",
111
+ "swivel chair",
112
+ "boat",
113
+ "bar",
114
+ "arcade machine",
115
+ "hovel",
116
+ "bus",
117
+ "towel",
118
+ "light",
119
+ "truck",
120
+ "tower",
121
+ "chandelier",
122
+ "awning",
123
+ "streetlight",
124
+ "booth",
125
+ "television receiver",
126
+ "airplane",
127
+ "dirt track",
128
+ "apparel",
129
+ "pole",
130
+ "land",
131
+ "bannister",
132
+ "escalator",
133
+ "ottoman",
134
+ "bottle",
135
+ "buffet",
136
+ "poster",
137
+ "stage",
138
+ "van",
139
+ "ship",
140
+ "fountain",
141
+ "conveyer belt",
142
+ "canopy",
143
+ "washer",
144
+ "plaything",
145
+ "swimming pool",
146
+ "stool",
147
+ "barrel",
148
+ "basket",
149
+ "waterfall",
150
+ "tent",
151
+ "bag",
152
+ "minibike",
153
+ "cradle",
154
+ "oven",
155
+ "ball",
156
+ "food",
157
+ "step",
158
+ "tank",
159
+ "trade name",
160
+ "microwave",
161
+ "pot",
162
+ "animal",
163
+ "bicycle",
164
+ "lake",
165
+ "dishwasher",
166
+ "screen",
167
+ "blanket",
168
+ "sculpture",
169
+ "hood",
170
+ "sconce",
171
+ "vase",
172
+ "traffic light",
173
+ "tray",
174
+ "ashcan",
175
+ "fan",
176
+ "pier",
177
+ "crt screen",
178
+ "plate",
179
+ "monitor",
180
+ "bulletin board",
181
+ "shower",
182
+ "radiator",
183
+ "glass",
184
+ "clock",
185
+ "flag",
186
+ ]
187
+
188
 
189
  def eval_model(args):
190
  # Model
 
193
  tokenizer = AutoTokenizer.from_pretrained(model_name)
194
 
195
  if "mpt" in model_name.lower():
196
+ model = LlavaMPTForCausalLM.from_pretrained(
197
+ model_name,
198
+ low_cpu_mem_usage=True,
199
+ torch_dtype=torch.float16,
200
+ use_cache=True,
201
+ ).cuda()
202
  else:
203
  # model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()
204
+ model = LlavaLlamaForCausalLM.from_pretrained(
205
+ model_name, torch_dtype=torch.float16, device_map="auto"
206
+ ) # .cuda()
207
+ image_processor = CLIPImageProcessor.from_pretrained(
208
+ model.config.mm_vision_tower, torch_dtype=torch.float16
209
+ )
210
 
211
  mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
212
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
213
  if mm_use_im_start_end:
214
+ tokenizer.add_tokens(
215
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
216
+ )
217
 
218
  vision_tower = model.get_model().vision_tower[0]
219
+ if vision_tower.device.type == "meta":
220
+ vision_tower = CLIPVisionModel.from_pretrained(
221
+ vision_tower.config._name_or_path,
222
+ torch_dtype=torch.float16,
223
+ low_cpu_mem_usage=True,
224
+ ).cuda()
225
  model.get_model().vision_tower[0] = vision_tower
226
  # else:
227
+ # vision_tower.to(device='cuda', dtype=torch.float16)
228
  vision_config = vision_tower.config
229
+ vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
230
+ [DEFAULT_IMAGE_PATCH_TOKEN]
231
+ )[0]
232
  vision_config.use_im_start_end = mm_use_im_start_end
233
  if mm_use_im_start_end:
234
+ (
235
+ vision_config.im_start_token,
236
+ vision_config.im_end_token,
237
+ ) = tokenizer.convert_tokens_to_ids(
238
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
239
+ )
240
  image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
241
 
242
  # paths for all images
243
+ images = sorted(
244
+ glob.glob("/mnt/proj74/xinlai/dataset/ade20k/images/training/*.jpg")
245
+ )
246
  results = []
247
  for i, image_file in enumerate(tqdm.tqdm(images)):
 
248
  # if i == 2:
249
  # break
250
 
 
260
  label_unique = np.unique(label)
261
 
262
  image = load_image(image_file)
263
+ image_tensor = image_processor.preprocess(image, return_tensors="pt")[
264
+ "pixel_values"
265
+ ][0]
266
  image_tensor = image_tensor.unsqueeze(0).half().cuda()
267
 
268
  for label in label_unique:
 
275
 
276
  # qs = args.query
277
  if mm_use_im_start_end:
278
+ qs = (
279
+ qs
280
+ + "\n"
281
+ + DEFAULT_IM_START_TOKEN
282
+ + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
283
+ + DEFAULT_IM_END_TOKEN
284
+ )
285
  else:
286
+ qs = qs + "\n" + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
287
 
288
  if "v1" in model_name.lower():
289
  conv_mode = "llava_v1"
 
293
  conv_mode = "multimodal"
294
 
295
  if args.conv_mode is not None and conv_mode != args.conv_mode:
296
+ print(
297
+ "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
298
+ conv_mode, args.conv_mode, args.conv_mode
299
+ )
300
+ )
301
  else:
302
  args.conv_mode = conv_mode
303
 
 
330
  images=image_tensor,
331
  # do_sample=True,
332
  # temperature=0.2,
333
+ max_new_tokens=512, # 1024,
334
+ stopping_criteria=[stopping_criteria],
335
+ )
336
 
337
  input_token_len = input_ids.shape[1]
338
+ n_diff_input_output = (
339
+ (input_ids != output_ids[:, :input_token_len]).sum().item()
340
+ )
341
  if n_diff_input_output > 0:
342
+ print(
343
+ f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
344
+ )
345
 
346
  outputs_list = []
347
  for output_id in output_ids:
348
+ outputs = tokenizer.batch_decode(
349
+ output_id[:, input_token_len:], skip_special_tokens=True
350
+ )[0]
351
  outputs = outputs.strip()
352
  if outputs.endswith(stop_str):
353
+ outputs = outputs[: -len(stop_str)]
354
  outputs = outputs.strip()
355
  outputs_list.append(outputs)
356
 
357
  for qs, outputs in zip(prompt_list, outputs_list):
358
  print("qs: {}, output: {}, image_file: {}".format(qs, outputs, image_file))
359
 
360
+ results.append(
361
+ {
362
+ "image_id": image_file.split("/")[-1],
363
+ "input": prompt_list,
364
+ "output": outputs_list,
365
+ }
366
+ )
367
 
368
  with open("/mnt/proj74/xinlai/LLM/LLaVA/ade20k_conversations.json", "w+") as f:
369
  json.dump(results, f)
370
 
371
+ # print(outputs)
372
+
373
 
374
  if __name__ == "__main__":
375
  parser = argparse.ArgumentParser()
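Aside (not part of the diff): the image loading and preprocessing path is shared by all of these batch scripts; a self-contained sketch of it, assuming a CUDA device and fp16 weights as in the code above:
```
# Sketch of the shared image path in these scripts: load (local file or URL),
# preprocess with the CLIP image processor, and move to fp16 on the GPU.
from io import BytesIO

import requests
import torch
from PIL import Image
from transformers import CLIPImageProcessor


def prepare_image(image_file: str, image_processor: CLIPImageProcessor) -> torch.Tensor:
    if image_file.startswith("http"):
        image = Image.open(BytesIO(requests.get(image_file).content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    pixel_values = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
    return pixel_values.unsqueeze(0).half().cuda()  # shape: (1, 3, H, W)
```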
model/llava/eval/run_llava_batch_v3.py CHANGED
@@ -1,24 +1,21 @@
1
  import argparse
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
- import torch
4
- import os
5
- from llava.conversation import conv_templates, SeparatorStyle
6
- from llava.utils import disable_torch_init
7
- from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
8
- from llava.model import *
9
- from llava.model.utils import KeywordsStoppingCriteria
10
-
11
- from PIL import Image
12
-
13
  import os
14
- import requests
15
- from PIL import Image
16
  from io import BytesIO
17
 
18
- import glob
19
  import numpy as np
20
- import json
 
21
  import tqdm
 
 
 
 
 
 
 
 
22
 
23
  DEFAULT_IMAGE_TOKEN = "<image>"
24
  DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
@@ -27,42 +24,167 @@ DEFAULT_IM_END_TOKEN = "<im_end>"
27
 
28
 
29
  def load_image(image_file):
30
- if image_file.startswith('http') or image_file.startswith('https'):
31
  response = requests.get(image_file)
32
- image = Image.open(BytesIO(response.content)).convert('RGB')
33
  else:
34
- image = Image.open(image_file).convert('RGB')
35
  return image
36
 
37
 
38
- classes = ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road',
39
- 'bed', 'windowpane', 'grass', 'cabinet', 'sidewalk',
40
- 'person', 'earth', 'door', 'table', 'mountain', 'plant',
41
- 'curtain', 'chair', 'car', 'water', 'painting', 'sofa',
42
- 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair',
43
- 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp',
44
- 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
45
- 'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
46
- 'skyscraper', 'fireplace', 'refrigerator', 'grandstand',
47
- 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow',
48
- 'screen door', 'stairway', 'river', 'bridge', 'bookcase',
49
- 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill',
50
- 'bench', 'countertop', 'stove', 'palm', 'kitchen island',
51
- 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine',
52
- 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
53
- 'chandelier', 'awning', 'streetlight', 'booth',
54
- 'television receiver', 'airplane', 'dirt track', 'apparel',
55
- 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle',
56
- 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain',
57
- 'conveyer belt', 'canopy', 'washer', 'plaything',
58
- 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall',
59
- 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food',
60
- 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal',
61
- 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket',
62
- 'sculpture', 'hood', 'sconce', 'vase', 'traffic light',
63
- 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate',
64
- 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
65
- 'clock', 'flag']
66
 
67
  def eval_model(args):
68
  # Model
@@ -71,38 +193,61 @@ def eval_model(args):
71
  tokenizer = AutoTokenizer.from_pretrained(model_name)
72
 
73
  if "mpt" in model_name.lower():
74
- model = LlavaMPTForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()
 
 
 
 
 
75
  else:
76
  # model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()
77
- model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')#.cuda()
78
- image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)
 
 
 
 
79
 
80
  mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
81
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
82
  if mm_use_im_start_end:
83
- tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
 
 
84
 
85
  vision_tower = model.get_model().vision_tower[0]
86
- if vision_tower.device.type == 'meta':
87
- vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda()
 
 
 
 
88
  model.get_model().vision_tower[0] = vision_tower
89
  else:
90
- vision_tower.to(device='cuda', dtype=torch.float16)
91
  vision_config = vision_tower.config
92
- vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
 
 
93
  vision_config.use_im_start_end = mm_use_im_start_end
94
  if mm_use_im_start_end:
95
- vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
 
 
 
 
 
96
  image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
97
 
98
  # paths for all images
99
- images = sorted(glob.glob("/mnt/proj74/xinlai/dataset/ade20k/images/training/*.jpg"))
 
 
100
  start, end = args.range.split(",")
101
  start, end = int(start), int(end)
102
  images = images[start:end]
103
  results = []
104
  for i, image_file in enumerate(tqdm.tqdm(images)):
105
-
106
  # if i == 2:
107
  # break
108
 
@@ -112,7 +257,9 @@ def eval_model(args):
112
  print("i: {}, len(images): {}".format(i, len(images)))
113
 
114
  image = load_image(image_file)
115
- image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
 
 
116
  image_tensor = image_tensor.unsqueeze(0).half().cuda()
117
 
118
  prompt_list = []
@@ -133,9 +280,15 @@ def eval_model(args):
133
 
134
  # qs = args.query
135
  if mm_use_im_start_end:
136
- qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
 
 
 
 
 
 
137
  else:
138
- qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
139
 
140
  if "v1" in model_name.lower():
141
  conv_mode = "llava_v1"
@@ -145,7 +298,11 @@ def eval_model(args):
145
  conv_mode = "multimodal"
146
 
147
  if args.conv_mode is not None and conv_mode != args.conv_mode:
148
- print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
 
 
 
 
149
  else:
150
  args.conv_mode = conv_mode
151
 
@@ -171,17 +328,24 @@ def eval_model(args):
171
  images=image_tensor,
172
  do_sample=True,
173
  temperature=0.2,
174
- max_new_tokens=512, #1024,
175
- stopping_criteria=[stopping_criteria])
 
176
 
177
  input_token_len = input_ids.shape[1]
178
- n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
 
 
179
  if n_diff_input_output > 0:
180
- print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
181
- outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
 
 
 
 
182
  outputs = outputs.strip()
183
  if outputs.endswith(stop_str):
184
- outputs = outputs[:-len(stop_str)]
185
  outputs = outputs.strip()
186
 
187
  # print("qs: {}, output: {}, image_file: {}".format(qs, outputs, image_file))
@@ -190,13 +354,23 @@ def eval_model(args):
190
  # results.append({'image_id': image_id, 'input': input_conv, 'output': outputs})
191
  output_list.append(outputs)
192
  image_id = image_file.split("/")[-1].split(".")[0]
193
- with open("/mnt/proj74/xinlai/LLM/LLaVA/generated/{}.json".format(image_id), "w+") as f:
194
- json.dump({'image_id': image_id, 'input_list': prompt_list, 'output_list': output_list}, f)
 
 
 
 
 
 
 
 
 
195
 
196
  # with open("/mnt/proj74/xinlai/LLM/LLaVA/ade20k_conversations.json", "w+") as f:
197
  # json.dump(results, f)
198
 
199
- # print(outputs)
 
200
 
201
  if __name__ == "__main__":
202
  parser = argparse.ArgumentParser()
 
1
  import argparse
2
+ import glob
3
+ import json
 
 
 
 
 
 
 
 
 
4
  import os
 
 
5
  from io import BytesIO
6
 
 
7
  import numpy as np
8
+ import requests
9
+ import torch
10
  import tqdm
11
+ from llava.conversation import SeparatorStyle, conv_templates
12
+ from llava.model import *
13
+ from llava.model.utils import KeywordsStoppingCriteria
14
+ from llava.utils import disable_torch_init
15
+ from PIL import Image
16
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
17
+ CLIPImageProcessor, CLIPVisionModel,
18
+ StoppingCriteria)
19
 
20
  DEFAULT_IMAGE_TOKEN = "<image>"
21
  DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
 
24
 
25
 
26
  def load_image(image_file):
27
+ if image_file.startswith("http") or image_file.startswith("https"):
28
  response = requests.get(image_file)
29
+ image = Image.open(BytesIO(response.content)).convert("RGB")
30
  else:
31
+ image = Image.open(image_file).convert("RGB")
32
  return image
33
 
34
 
35
+ classes = [
36
+ "wall",
37
+ "building",
38
+ "sky",
39
+ "floor",
40
+ "tree",
41
+ "ceiling",
42
+ "road",
43
+ "bed",
44
+ "windowpane",
45
+ "grass",
46
+ "cabinet",
47
+ "sidewalk",
48
+ "person",
49
+ "earth",
50
+ "door",
51
+ "table",
52
+ "mountain",
53
+ "plant",
54
+ "curtain",
55
+ "chair",
56
+ "car",
57
+ "water",
58
+ "painting",
59
+ "sofa",
60
+ "shelf",
61
+ "house",
62
+ "sea",
63
+ "mirror",
64
+ "rug",
65
+ "field",
66
+ "armchair",
67
+ "seat",
68
+ "fence",
69
+ "desk",
70
+ "rock",
71
+ "wardrobe",
72
+ "lamp",
73
+ "bathtub",
74
+ "railing",
75
+ "cushion",
76
+ "base",
77
+ "box",
78
+ "column",
79
+ "signboard",
80
+ "chest of drawers",
81
+ "counter",
82
+ "sand",
83
+ "sink",
84
+ "skyscraper",
85
+ "fireplace",
86
+ "refrigerator",
87
+ "grandstand",
88
+ "path",
89
+ "stairs",
90
+ "runway",
91
+ "case",
92
+ "pool table",
93
+ "pillow",
94
+ "screen door",
95
+ "stairway",
96
+ "river",
97
+ "bridge",
98
+ "bookcase",
99
+ "blind",
100
+ "coffee table",
101
+ "toilet",
102
+ "flower",
103
+ "book",
104
+ "hill",
105
+ "bench",
106
+ "countertop",
107
+ "stove",
108
+ "palm",
109
+ "kitchen island",
110
+ "computer",
111
+ "swivel chair",
112
+ "boat",
113
+ "bar",
114
+ "arcade machine",
115
+ "hovel",
116
+ "bus",
117
+ "towel",
118
+ "light",
119
+ "truck",
120
+ "tower",
121
+ "chandelier",
122
+ "awning",
123
+ "streetlight",
124
+ "booth",
125
+ "television receiver",
126
+ "airplane",
127
+ "dirt track",
128
+ "apparel",
129
+ "pole",
130
+ "land",
131
+ "bannister",
132
+ "escalator",
133
+ "ottoman",
134
+ "bottle",
135
+ "buffet",
136
+ "poster",
137
+ "stage",
138
+ "van",
139
+ "ship",
140
+ "fountain",
141
+ "conveyer belt",
142
+ "canopy",
143
+ "washer",
144
+ "plaything",
145
+ "swimming pool",
146
+ "stool",
147
+ "barrel",
148
+ "basket",
149
+ "waterfall",
150
+ "tent",
151
+ "bag",
152
+ "minibike",
153
+ "cradle",
154
+ "oven",
155
+ "ball",
156
+ "food",
157
+ "step",
158
+ "tank",
159
+ "trade name",
160
+ "microwave",
161
+ "pot",
162
+ "animal",
163
+ "bicycle",
164
+ "lake",
165
+ "dishwasher",
166
+ "screen",
167
+ "blanket",
168
+ "sculpture",
169
+ "hood",
170
+ "sconce",
171
+ "vase",
172
+ "traffic light",
173
+ "tray",
174
+ "ashcan",
175
+ "fan",
176
+ "pier",
177
+ "crt screen",
178
+ "plate",
179
+ "monitor",
180
+ "bulletin board",
181
+ "shower",
182
+ "radiator",
183
+ "glass",
184
+ "clock",
185
+ "flag",
186
+ ]
187
+
188
 
189
  def eval_model(args):
190
  # Model
 
193
  tokenizer = AutoTokenizer.from_pretrained(model_name)
194
 
195
  if "mpt" in model_name.lower():
196
+ model = LlavaMPTForCausalLM.from_pretrained(
197
+ model_name,
198
+ low_cpu_mem_usage=True,
199
+ torch_dtype=torch.float16,
200
+ use_cache=True,
201
+ ).cuda()
202
  else:
203
  # model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()
204
+ model = LlavaLlamaForCausalLM.from_pretrained(
205
+ model_name, torch_dtype=torch.float16, device_map="auto"
206
+ ) # .cuda()
207
+ image_processor = CLIPImageProcessor.from_pretrained(
208
+ model.config.mm_vision_tower, torch_dtype=torch.float16
209
+ )
210
 
211
  mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
212
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
213
  if mm_use_im_start_end:
214
+ tokenizer.add_tokens(
215
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
216
+ )
217
 
218
  vision_tower = model.get_model().vision_tower[0]
219
+ if vision_tower.device.type == "meta":
220
+ vision_tower = CLIPVisionModel.from_pretrained(
221
+ vision_tower.config._name_or_path,
222
+ torch_dtype=torch.float16,
223
+ low_cpu_mem_usage=True,
224
+ ).cuda()
225
  model.get_model().vision_tower[0] = vision_tower
226
  else:
227
+ vision_tower.to(device="cuda", dtype=torch.float16)
228
  vision_config = vision_tower.config
229
+ vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
230
+ [DEFAULT_IMAGE_PATCH_TOKEN]
231
+ )[0]
232
  vision_config.use_im_start_end = mm_use_im_start_end
233
  if mm_use_im_start_end:
234
+ (
235
+ vision_config.im_start_token,
236
+ vision_config.im_end_token,
237
+ ) = tokenizer.convert_tokens_to_ids(
238
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
239
+ )
240
  image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
241
 
242
  # paths for all images
243
+ images = sorted(
244
+ glob.glob("/mnt/proj74/xinlai/dataset/ade20k/images/training/*.jpg")
245
+ )
246
  start, end = args.range.split(",")
247
  start, end = int(start), int(end)
248
  images = images[start:end]
249
  results = []
250
  for i, image_file in enumerate(tqdm.tqdm(images)):
 
251
  # if i == 2:
252
  # break
253
 
 
257
  print("i: {}, len(images): {}".format(i, len(images)))
258
 
259
  image = load_image(image_file)
260
+ image_tensor = image_processor.preprocess(image, return_tensors="pt")[
261
+ "pixel_values"
262
+ ][0]
263
  image_tensor = image_tensor.unsqueeze(0).half().cuda()
264
 
265
  prompt_list = []
 
280
 
281
  # qs = args.query
282
  if mm_use_im_start_end:
283
+ qs = (
284
+ qs
285
+ + "\n"
286
+ + DEFAULT_IM_START_TOKEN
287
+ + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
288
+ + DEFAULT_IM_END_TOKEN
289
+ )
290
  else:
291
+ qs = qs + "\n" + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
292
 
293
  if "v1" in model_name.lower():
294
  conv_mode = "llava_v1"
 
298
  conv_mode = "multimodal"
299
 
300
  if args.conv_mode is not None and conv_mode != args.conv_mode:
301
+ print(
302
+ "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
303
+ conv_mode, args.conv_mode, args.conv_mode
304
+ )
305
+ )
306
  else:
307
  args.conv_mode = conv_mode
308
 
 
328
  images=image_tensor,
329
  do_sample=True,
330
  temperature=0.2,
331
+ max_new_tokens=512, # 1024,
332
+ stopping_criteria=[stopping_criteria],
333
+ )
334
 
335
  input_token_len = input_ids.shape[1]
336
+ n_diff_input_output = (
337
+ (input_ids != output_ids[:, :input_token_len]).sum().item()
338
+ )
339
  if n_diff_input_output > 0:
340
+ print(
341
+ f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
342
+ )
343
+ outputs = tokenizer.batch_decode(
344
+ output_ids[:, input_token_len:], skip_special_tokens=True
345
+ )[0]
346
  outputs = outputs.strip()
347
  if outputs.endswith(stop_str):
348
+ outputs = outputs[: -len(stop_str)]
349
  outputs = outputs.strip()
350
 
351
  # print("qs: {}, output: {}, image_file: {}".format(qs, outputs, image_file))
 
354
  # results.append({'image_id': image_id, 'input': input_conv, 'output': outputs})
355
  output_list.append(outputs)
356
  image_id = image_file.split("/")[-1].split(".")[0]
357
+ with open(
358
+ "/mnt/proj74/xinlai/LLM/LLaVA/generated/{}.json".format(image_id), "w+"
359
+ ) as f:
360
+ json.dump(
361
+ {
362
+ "image_id": image_id,
363
+ "input_list": prompt_list,
364
+ "output_list": output_list,
365
+ },
366
+ f,
367
+ )
368
 
369
  # with open("/mnt/proj74/xinlai/LLM/LLaVA/ade20k_conversations.json", "w+") as f:
370
  # json.dump(results, f)
371
 
372
+ # print(outputs)
373
+
374
 
375
  if __name__ == "__main__":
376
  parser = argparse.ArgumentParser()
model/llava/eval/summarize_gpt_review.py CHANGED
@@ -4,23 +4,25 @@ from collections import defaultdict

  import numpy as np

-
- if __name__ == '__main__':
+ if __name__ == "__main__":
  base_dir = "vqa/reviews/coco2014_val80"
- review_files = [x for x in os.listdir(base_dir) if x.endswith('.jsonl') and x.startswith('gpt4_text')]
+ review_files = [
+ x
+ for x in os.listdir(base_dir)
+ if x.endswith(".jsonl") and x.startswith("gpt4_text")
+ ]

  for review_file in sorted(review_files):
- config = review_file.replace('gpt4_text_', '').replace('.jsonl', '')
+ config = review_file.replace("gpt4_text_", "").replace(".jsonl", "")
  scores = defaultdict(list)
- print(f'GPT-4 vs. {config}')
+ print(f"GPT-4 vs. {config}")
  with open(os.path.join(base_dir, review_file)) as f:
  for review_str in f:
  review = json.loads(review_str)
- scores[review['category']].append(review['tuple'])
- scores['all'].append(review['tuple'])
+ scores[review["category"]].append(review["tuple"])
+ scores["all"].append(review["tuple"])
  for k, v in scores.items():
  stats = np.asarray(v).mean(0).tolist()
  stats = [round(x, 3) for x in stats]
- print(k, stats, round(stats[1]/stats[0]*100, 1))
- print('=================================')
-
+ print(k, stats, round(stats[1] / stats[0] * 100, 1))
+ print("=================================")
model/llava/model/__init__.py CHANGED
@@ -1,2 +1,2 @@
- from .llava import LlavaLlamaForCausalLM, LlavaConfig
- from .llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig
+ from .llava import LlavaConfig, LlavaLlamaForCausalLM
+ from .llava_mpt import LlavaMPTConfig, LlavaMPTForCausalLM
model/llava/model/apply_delta.py CHANGED
@@ -5,32 +5,40 @@ python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~
  import argparse

  import torch
- from tqdm import tqdm
- from transformers import AutoTokenizer, AutoModelForCausalLM
  from llava import LlavaLlamaForCausalLM
+ from tqdm import tqdm
+ from transformers import AutoModelForCausalLM, AutoTokenizer


  def apply_delta(base_model_path, target_model_path, delta_path):
  print("Loading base model")
  base = AutoModelForCausalLM.from_pretrained(
- base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
+ )

  print("Loading delta")
- delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+ delta = LlavaLlamaForCausalLM.from_pretrained(
+ delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
+ )
  delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)

  print("Applying delta")
  for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
  if name not in base.state_dict():
- assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
+ assert name in [
+ "model.mm_projector.weight",
+ "model.mm_projector.bias",
+ ], f"{name} not in base model"
  continue
  if param.data.shape == base.state_dict()[name].shape:
  param.data += base.state_dict()[name]
  else:
- assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
- f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
+ assert name in [
+ "model.embed_tokens.weight",
+ "lm_head.weight",
+ ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
  bparam = base.state_dict()[name]
- param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
+ param.data[: bparam.shape[0], : bparam.shape[1]] += bparam

  print("Saving target model")
  delta.save_pretrained(target_model_path)
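Aside (not part of the diff): a hypothetical call to the function above, assuming the package exposes it as `llava.model.apply_delta`; the paths are placeholders.
```
from llava.model.apply_delta import apply_delta  # assumed module path

apply_delta(
    base_model_path="/path/to/llama-7b",      # hypothetical base weights
    target_model_path="/path/to/llava-7b",    # hypothetical output directory
    delta_path="/path/to/llava-7b-delta",     # hypothetical released delta
)
```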
model/llava/model/consolidate.py CHANGED
@@ -5,15 +5,17 @@ python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_
  import argparse

  import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
  from llava.model import *
  from llava.model.utils import auto_upgrade
+ from transformers import AutoModelForCausalLM, AutoTokenizer


  def consolidate_ckpt(src_path, dst_path):
  print("Loading model")
  auto_upgrade(src_path)
- src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+ src_model = AutoModelForCausalLM.from_pretrained(
+ src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
+ )
  src_tokenizer = AutoTokenizer.from_pretrained(src_path)
  src_model.save_pretrained(dst_path)
  src_tokenizer.save_pretrained(dst_path)
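Aside (not part of the diff): an equivalent in-Python call to the module docstring's `python3 -m llava.model.consolidate --src ... --dst ...` usage; paths are placeholders.
```
from llava.model.consolidate import consolidate_ckpt  # assumed module path

consolidate_ckpt(
    src_path="/path/to/llava-7b",               # hypothetical source checkpoint
    dst_path="/path/to/llava-7b_consolidated",  # hypothetical destination
)
```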
model/llava/model/llava.py CHANGED
@@ -19,13 +19,11 @@ import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from torch.nn import CrossEntropyLoss
-
- from transformers import AutoConfig, AutoModelForCausalLM, \
- LlamaConfig, LlamaModel, LlamaForCausalLM, \
- CLIPVisionModel, CLIPImageProcessor
-
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-

  DEFAULT_IMAGE_TOKEN = "<image>"
  DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
@@ -45,25 +43,33 @@ class LlavaLlamaModel(LlamaModel):
45
 
46
  if hasattr(config, "mm_vision_tower"):
47
  # HACK: for FSDP
48
- self.vision_tower = [CLIPVisionModel.from_pretrained(config.mm_vision_tower)]
 
 
49
 
50
  if hasattr(config, "use_mm_proj"):
51
  self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
52
 
53
- def initialize_vision_modules(self, vision_tower, mm_vision_select_layer,
54
- pretrain_mm_mlp_adapter=None, tune_mm_mlp_adapter=False, precision='bf16'):
 
 
 
 
 
 
55
  self.config.mm_vision_tower = vision_tower
56
 
57
  image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
58
 
59
- if not hasattr(self, 'vision_tower'):
60
  vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
61
  else:
62
  vision_tower = self.vision_tower[0]
63
  vision_tower.requires_grad_(False)
64
- if precision == 'bf16':
65
  vision_tower = vision_tower.to(torch.bfloat16)
66
- elif precision == 'fp16':
67
  vision_tower = vision_tower.to(torch.half)
68
  else:
69
  vision_tower = vision_tower.to(torch.float32)
@@ -77,17 +83,23 @@ class LlavaLlamaModel(LlamaModel):
77
  self.config.mm_hidden_size = vision_config.hidden_size
78
  self.config.mm_vision_select_layer = mm_vision_select_layer
79
 
80
- if not hasattr(self, 'mm_projector'):
81
- self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.hidden_size)
 
 
82
 
83
  if pretrain_mm_mlp_adapter is not None:
84
- mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
85
- self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})
 
 
 
 
86
 
87
  return dict(
88
  image_processor=image_processor,
89
  image_token_len=num_patches,
90
- vision_config=vision_config
91
  )
92
 
93
  def forward(
@@ -102,15 +114,18 @@ class LlavaLlamaModel(LlamaModel):
102
  images: Optional[torch.FloatTensor] = None,
103
  return_dict: Optional[bool] = None,
104
  ) -> Union[Tuple, BaseModelOutputWithPast]:
105
-
106
  # HACK: replace back original embeddings for LLaVA pretraining
107
- orig_embeds_params = getattr(self, 'orig_embeds_params', None)
108
-
109
  if inputs_embeds is None:
110
  inputs_embeds = self.embed_tokens(input_ids)
111
 
112
- vision_tower = getattr(self, 'vision_tower', None)
113
- if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
 
 
 
 
114
  # TODO: this is a modified multimodal LLM -- Haotian Liu
115
  vision_tower = vision_tower[0] # HACK: for FSDP
116
  with torch.no_grad():
@@ -118,26 +133,41 @@ class LlavaLlamaModel(LlamaModel):
118
  # variable length images
119
  image_features = []
120
  for image in images:
121
- image_forward_out = vision_tower(image.unsqueeze(0), output_hidden_states=True)
122
- select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1)
123
- select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
 
 
 
 
 
 
124
  image_feature = select_hidden_state[:, 1:]
125
  image_feature = image_feature.contiguous()
126
  image_features.append(image_feature)
127
  torch.cuda.empty_cache()
128
  else:
129
  image_forward_outs = vision_tower(images, output_hidden_states=True)
130
- select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1)
131
- select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
 
 
 
 
132
  image_features = select_hidden_state[:, 1:]
133
  image_features = image_features.contiguous()
134
  torch.cuda.empty_cache()
135
 
136
  if type(images) is list:
137
- image_features = [self.mm_projector(image_feature)[0] for image_feature in image_features]
 
 
 
138
  else:
139
  image_features = self.mm_projector(image_features)
140
- dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
 
 
141
  dummy_image_features = self.mm_projector(dummy_image_features)
142
 
143
  new_input_embeds = []
@@ -145,48 +175,128 @@ class LlavaLlamaModel(LlamaModel):
145
  for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
146
  if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0:
147
  # multimodal LLM, but the current sample is not multimodal
148
- cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
 
 
149
  new_input_embeds.append(cur_input_embeds)
150
  cur_image_idx += 1
151
  continue
152
  if vision_tower.config.use_im_start_end:
153
  cur_image_features = image_features[cur_image_idx]
154
  num_patches = cur_image_features.shape[0]
155
- if (cur_input_ids == vision_tower.config.im_start_token).sum() != (cur_input_ids == vision_tower.config.im_end_token).sum():
156
- raise ValueError("The number of image start tokens and image end tokens should be the same.")
157
- image_start_tokens = torch.where(cur_input_ids == vision_tower.config.im_start_token)[0]
 
 
 
 
 
 
158
  for image_start_token_pos in image_start_tokens:
159
- cur_image_features = image_features[cur_image_idx].to(device=cur_input_embeds.device)
 
 
160
  num_patches = cur_image_features.shape[0]
161
- if cur_input_ids[image_start_token_pos + num_patches + 1] != vision_tower.config.im_end_token:
162
- raise ValueError("The image end token should follow the image start token.")
 
 
 
 
 
163
  if orig_embeds_params is not None:
164
- cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  else:
166
- cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
 
 
 
 
 
 
 
 
 
167
  cur_image_idx += 1
168
  new_input_embeds.append(cur_new_input_embeds)
169
  else:
170
  cur_image_features = image_features[cur_image_idx]
171
  num_patches = cur_image_features.shape[0]
172
- if (cur_input_ids == vision_tower.config.im_patch_token).sum() != num_patches:
173
- raise ValueError("The number of image patch tokens should be the same as the number of image patches.")
174
- masked_indices = torch.where(cur_input_ids == vision_tower.config.im_patch_token)[0]
 
 
 
 
 
 
175
  mask_index_start = masked_indices[0]
176
- if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any():
177
- raise ValueError("The image patch tokens should be consecutive.")
 
 
 
 
 
 
 
 
 
 
178
  if orig_embeds_params is not None:
179
- cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(), cur_image_features, cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0)
 
 
 
 
 
 
 
 
 
180
  else:
181
- cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0)
 
 
 
 
 
 
 
182
  new_input_embeds.append(cur_new_input_embeds)
183
  cur_image_idx += 1
184
  inputs_embeds = torch.stack(new_input_embeds, dim=0)
185
  return super(LlavaLlamaModel, self).forward(
186
- input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
187
- inputs_embeds=inputs_embeds, use_cache=use_cache,
188
- output_attentions=output_attentions, output_hidden_states=output_hidden_states,
189
- return_dict=return_dict
 
 
 
 
190
  )
191
 
192
 
@@ -218,11 +328,19 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
218
  images: Optional[torch.FloatTensor] = None,
219
  return_dict: Optional[bool] = None,
220
  ) -> Union[Tuple, CausalLMOutputWithPast]:
221
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
222
  output_hidden_states = (
223
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
224
  )
225
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
226
 
227
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
228
  outputs = self.model(
@@ -234,7 +352,7 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
234
  output_attentions=output_attentions,
235
  output_hidden_states=output_hidden_states,
236
  return_dict=return_dict,
237
- images=images
238
  )
239
  hidden_states = outputs[0]
240
  logits = self.lm_head(hidden_states)
@@ -269,7 +387,12 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
269
  )
270
 
271
  def prepare_inputs_for_generation(
272
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
 
 
 
 
 
273
  ):
274
  if past_key_values:
275
  input_ids = input_ids[:, -1:]
@@ -291,16 +414,28 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
291
  )
292
  return model_inputs
293
 
294
- def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, num_new_tokens, device,
295
- tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None):
 
 
 
 
 
 
 
296
  vision_config = self.get_model().vision_tower[0].config
297
  vision_config.use_im_start_end = mm_use_im_start_end
298
 
299
  if mm_use_im_start_end:
300
  # num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
301
-
302
  # self.resize_token_embeddings(len(tokenizer))
303
- vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
 
 
 
 
 
304
 
305
  # if num_new_tokens > 0:
306
  # input_embeddings = self.get_input_embeddings().weight.data
@@ -315,24 +450,35 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
315
  # output_embeddings[-num_new_tokens:] = output_embeddings_avg
316
 
317
  if tune_mm_mlp_adapter:
318
- self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
 
 
319
  for p in self.get_input_embeddings().parameters():
320
  p.requires_grad = True
321
  for p in self.get_output_embeddings().parameters():
322
  p.requires_grad = False
323
 
324
  if pretrain_mm_mlp_adapter:
325
- mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
326
- embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
 
 
327
  assert num_new_tokens == 2
328
  if input_embeddings.shape == embed_tokens_weight.shape:
329
- input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
 
 
330
  elif embed_tokens_weight.shape[0] == num_new_tokens:
331
  input_embeddings[-num_new_tokens:] = embed_tokens_weight
332
  else:
333
- raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
 
 
 
 
 
 
334
 
335
- vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
336
 
337
  AutoConfig.register("llava", LlavaConfig)
338
  AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
 
19
  import torch.nn as nn
20
  import torch.nn.functional as F
21
  from torch.nn import CrossEntropyLoss
22
+ from transformers import (AutoConfig, AutoModelForCausalLM, CLIPImageProcessor,
23
+ CLIPVisionModel, LlamaConfig, LlamaForCausalLM,
24
+ LlamaModel)
25
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
26
+ CausalLMOutputWithPast)
 
 
27
 
28
  DEFAULT_IMAGE_TOKEN = "<image>"
29
  DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
 
43
 
44
  if hasattr(config, "mm_vision_tower"):
45
  # HACK: for FSDP
46
+ self.vision_tower = [
47
+ CLIPVisionModel.from_pretrained(config.mm_vision_tower)
48
+ ]
49
 
50
  if hasattr(config, "use_mm_proj"):
51
  self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
52
 
53
+ def initialize_vision_modules(
54
+ self,
55
+ vision_tower,
56
+ mm_vision_select_layer,
57
+ pretrain_mm_mlp_adapter=None,
58
+ tune_mm_mlp_adapter=False,
59
+ precision="bf16",
60
+ ):
61
  self.config.mm_vision_tower = vision_tower
62
 
63
  image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
64
 
65
+ if not hasattr(self, "vision_tower"):
66
  vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
67
  else:
68
  vision_tower = self.vision_tower[0]
69
  vision_tower.requires_grad_(False)
70
+ if precision == "bf16":
71
  vision_tower = vision_tower.to(torch.bfloat16)
72
+ elif precision == "fp16":
73
  vision_tower = vision_tower.to(torch.half)
74
  else:
75
  vision_tower = vision_tower.to(torch.float32)
 
83
  self.config.mm_hidden_size = vision_config.hidden_size
84
  self.config.mm_vision_select_layer = mm_vision_select_layer
85
 
86
+ if not hasattr(self, "mm_projector"):
87
+ self.mm_projector = nn.Linear(
88
+ vision_config.hidden_size, self.config.hidden_size
89
+ )
90
 
91
  if pretrain_mm_mlp_adapter is not None:
92
+ mm_projector_weights = torch.load(
93
+ pretrain_mm_mlp_adapter, map_location="cpu"
94
+ )
95
+ self.mm_projector.load_state_dict(
96
+ {k.split(".")[-1]: v for k, v in mm_projector_weights.items()}
97
+ )
98
 
99
  return dict(
100
  image_processor=image_processor,
101
  image_token_len=num_patches,
102
+ vision_config=vision_config,
103
  )
104
 
105
  def forward(
 
114
  images: Optional[torch.FloatTensor] = None,
115
  return_dict: Optional[bool] = None,
116
  ) -> Union[Tuple, BaseModelOutputWithPast]:
 
117
  # HACK: replace back original embeddings for LLaVA pretraining
118
+ orig_embeds_params = getattr(self, "orig_embeds_params", None)
119
+
120
  if inputs_embeds is None:
121
  inputs_embeds = self.embed_tokens(input_ids)
122
 
123
+ vision_tower = getattr(self, "vision_tower", None)
124
+ if (
125
+ vision_tower is not None
126
+ and (input_ids.shape[1] != 1 or self.training)
127
+ and images is not None
128
+ ):
129
  # TODO: this is a modified multimodal LLM -- Haotian Liu
130
  vision_tower = vision_tower[0] # HACK: for FSDP
131
  with torch.no_grad():
 
133
  # variable length images
134
  image_features = []
135
  for image in images:
136
+ image_forward_out = vision_tower(
137
+ image.unsqueeze(0), output_hidden_states=True
138
+ )
139
+ select_hidden_state_layer = getattr(
140
+ self.config, "mm_vision_select_layer", -1
141
+ )
142
+ select_hidden_state = image_forward_out.hidden_states[
143
+ select_hidden_state_layer
144
+ ]
145
  image_feature = select_hidden_state[:, 1:]
146
  image_feature = image_feature.contiguous()
147
  image_features.append(image_feature)
148
  torch.cuda.empty_cache()
149
  else:
150
  image_forward_outs = vision_tower(images, output_hidden_states=True)
151
+ select_hidden_state_layer = getattr(
152
+ self.config, "mm_vision_select_layer", -1
153
+ )
154
+ select_hidden_state = image_forward_outs.hidden_states[
155
+ select_hidden_state_layer
156
+ ]
157
  image_features = select_hidden_state[:, 1:]
158
  image_features = image_features.contiguous()
159
  torch.cuda.empty_cache()
160
 
161
  if type(images) is list:
162
+ image_features = [
163
+ self.mm_projector(image_feature)[0]
164
+ for image_feature in image_features
165
+ ]
166
  else:
167
  image_features = self.mm_projector(image_features)
168
+ dummy_image_features = torch.zeros(
169
+ 256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype
170
+ )
171
  dummy_image_features = self.mm_projector(dummy_image_features)
172
 
173
  new_input_embeds = []
 
175
  for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
176
  if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0:
177
  # multimodal LLM, but the current sample is not multimodal
178
+ cur_input_embeds = (
179
+ cur_input_embeds + (0.0 * dummy_image_features).sum()
180
+ )
181
  new_input_embeds.append(cur_input_embeds)
182
  cur_image_idx += 1
183
  continue
184
  if vision_tower.config.use_im_start_end:
185
  cur_image_features = image_features[cur_image_idx]
186
  num_patches = cur_image_features.shape[0]
187
+ if (cur_input_ids == vision_tower.config.im_start_token).sum() != (
188
+ cur_input_ids == vision_tower.config.im_end_token
189
+ ).sum():
190
+ raise ValueError(
191
+ "The number of image start tokens and image end tokens should be the same."
192
+ )
193
+ image_start_tokens = torch.where(
194
+ cur_input_ids == vision_tower.config.im_start_token
195
+ )[0]
196
  for image_start_token_pos in image_start_tokens:
197
+ cur_image_features = image_features[cur_image_idx].to(
198
+ device=cur_input_embeds.device
199
+ )
200
  num_patches = cur_image_features.shape[0]
201
+ if (
202
+ cur_input_ids[image_start_token_pos + num_patches + 1]
203
+ != vision_tower.config.im_end_token
204
+ ):
205
+ raise ValueError(
206
+ "The image end token should follow the image start token."
207
+ )
208
  if orig_embeds_params is not None:
209
+ cur_new_input_embeds = torch.cat(
210
+ (
211
+ cur_input_embeds[:image_start_token_pos].detach(),
212
+ cur_input_embeds[
213
+ image_start_token_pos : image_start_token_pos
214
+ + 1
215
+ ],
216
+ cur_image_features,
217
+ cur_input_embeds[
218
+ image_start_token_pos
219
+ + num_patches
220
+ + 1 : image_start_token_pos
221
+ + num_patches
222
+ + 2
223
+ ],
224
+ cur_input_embeds[
225
+ image_start_token_pos + num_patches + 2 :
226
+ ].detach(),
227
+ ),
228
+ dim=0,
229
+ )
230
  else:
231
+ cur_new_input_embeds = torch.cat(
232
+ (
233
+ cur_input_embeds[: image_start_token_pos + 1],
234
+ cur_image_features,
235
+ cur_input_embeds[
236
+ image_start_token_pos + num_patches + 1 :
237
+ ],
238
+ ),
239
+ dim=0,
240
+ )
241
  cur_image_idx += 1
242
  new_input_embeds.append(cur_new_input_embeds)
243
  else:
244
  cur_image_features = image_features[cur_image_idx]
245
  num_patches = cur_image_features.shape[0]
246
+ if (
247
+ cur_input_ids == vision_tower.config.im_patch_token
248
+ ).sum() != num_patches:
249
+ raise ValueError(
250
+ "The number of image patch tokens should be the same as the number of image patches."
251
+ )
252
+ masked_indices = torch.where(
253
+ cur_input_ids == vision_tower.config.im_patch_token
254
+ )[0]
255
  mask_index_start = masked_indices[0]
256
+ if (
257
+ masked_indices
258
+ != torch.arange(
259
+ mask_index_start,
260
+ mask_index_start + num_patches,
261
+ device=masked_indices.device,
262
+ dtype=masked_indices.dtype,
263
+ )
264
+ ).any():
265
+ raise ValueError(
266
+ "The image patch tokens should be consecutive."
267
+ )
268
  if orig_embeds_params is not None:
269
+ cur_new_input_embeds = torch.cat(
270
+ (
271
+ cur_input_embeds[:mask_index_start].detach(),
272
+ cur_image_features,
273
+ cur_input_embeds[
274
+ mask_index_start + num_patches :
275
+ ].detach(),
276
+ ),
277
+ dim=0,
278
+ )
279
  else:
280
+ cur_new_input_embeds = torch.cat(
281
+ (
282
+ cur_input_embeds[:mask_index_start],
283
+ cur_image_features,
284
+ cur_input_embeds[mask_index_start + num_patches :],
285
+ ),
286
+ dim=0,
287
+ )
288
  new_input_embeds.append(cur_new_input_embeds)
289
  cur_image_idx += 1
290
  inputs_embeds = torch.stack(new_input_embeds, dim=0)
291
  return super(LlavaLlamaModel, self).forward(
292
+ input_ids=None,
293
+ attention_mask=attention_mask,
294
+ past_key_values=past_key_values,
295
+ inputs_embeds=inputs_embeds,
296
+ use_cache=use_cache,
297
+ output_attentions=output_attentions,
298
+ output_hidden_states=output_hidden_states,
299
+ return_dict=return_dict,
300
  )
301
 
302
 
 
328
  images: Optional[torch.FloatTensor] = None,
329
  return_dict: Optional[bool] = None,
330
  ) -> Union[Tuple, CausalLMOutputWithPast]:
331
+ output_attentions = (
332
+ output_attentions
333
+ if output_attentions is not None
334
+ else self.config.output_attentions
335
+ )
336
  output_hidden_states = (
337
+ output_hidden_states
338
+ if output_hidden_states is not None
339
+ else self.config.output_hidden_states
340
+ )
341
+ return_dict = (
342
+ return_dict if return_dict is not None else self.config.use_return_dict
343
  )
 
344
 
345
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
346
  outputs = self.model(
 
352
  output_attentions=output_attentions,
353
  output_hidden_states=output_hidden_states,
354
  return_dict=return_dict,
355
+ images=images,
356
  )
357
  hidden_states = outputs[0]
358
  logits = self.lm_head(hidden_states)
 
387
  )
388
 
389
  def prepare_inputs_for_generation(
390
+ self,
391
+ input_ids,
392
+ past_key_values=None,
393
+ attention_mask=None,
394
+ inputs_embeds=None,
395
+ **kwargs,
396
  ):
397
  if past_key_values:
398
  input_ids = input_ids[:, -1:]
 
414
  )
415
  return model_inputs
416
 
417
+ def initialize_vision_tokenizer(
418
+ self,
419
+ mm_use_im_start_end,
420
+ tokenizer,
421
+ num_new_tokens,
422
+ device,
423
+ tune_mm_mlp_adapter=False,
424
+ pretrain_mm_mlp_adapter=None,
425
+ ):
426
  vision_config = self.get_model().vision_tower[0].config
427
  vision_config.use_im_start_end = mm_use_im_start_end
428
 
429
  if mm_use_im_start_end:
430
  # num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
431
+
432
  # self.resize_token_embeddings(len(tokenizer))
433
+ (
434
+ vision_config.im_start_token,
435
+ vision_config.im_end_token,
436
+ ) = tokenizer.convert_tokens_to_ids(
437
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
438
+ )
439
 
440
  # if num_new_tokens > 0:
441
  # input_embeddings = self.get_input_embeddings().weight.data
 
450
  # output_embeddings[-num_new_tokens:] = output_embeddings_avg
451
 
452
  if tune_mm_mlp_adapter:
453
+ self.get_model().orig_embeds_params = [
454
+ self.get_input_embeddings().weight.data.clone().to(device=device)
455
+ ]
456
  for p in self.get_input_embeddings().parameters():
457
  p.requires_grad = True
458
  for p in self.get_output_embeddings().parameters():
459
  p.requires_grad = False
460
 
461
  if pretrain_mm_mlp_adapter:
462
+ mm_projector_weights = torch.load(
463
+ pretrain_mm_mlp_adapter, map_location="cpu"
464
+ )
465
+ embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
466
  assert num_new_tokens == 2
467
  if input_embeddings.shape == embed_tokens_weight.shape:
468
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[
469
+ -num_new_tokens:
470
+ ]
471
  elif embed_tokens_weight.shape[0] == num_new_tokens:
472
  input_embeddings[-num_new_tokens:] = embed_tokens_weight
473
  else:
474
+ raise ValueError(
475
+ f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
476
+ )
477
+
478
+ vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
479
+ [DEFAULT_IMAGE_PATCH_TOKEN]
480
+ )[0]
481
 
 
482
 
483
  AutoConfig.register("llava", LlavaConfig)
484
  AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
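
A quick aside on the `llava.py` changes above: most of the reflow is cosmetic, but the core behaviour is the splice of projected CLIP features into the token embedding sequence at the `<im_patch>` positions. The following is a minimal, self-contained sketch of that step with toy sizes and a hypothetical patch-token id, not the repository's real classes or config:

```
# Minimal sketch of the placeholder-splicing step in LlavaLlamaModel.forward:
# find the contiguous run of <im_patch> tokens and replace their embeddings
# with projected image features. Names, ids, and sizes are toy values.
import torch
import torch.nn as nn

IM_PATCH_TOKEN = 32000                     # hypothetical id for "<im_patch>"
num_patches, clip_dim, hidden = 4, 6, 8    # the real model uses 256 patches

embed = nn.Embedding(32001, hidden)
mm_projector = nn.Linear(clip_dim, hidden)  # CLIP feature -> LM hidden size

input_ids = torch.tensor([1, 5, 9] + [IM_PATCH_TOKEN] * num_patches + [7, 2])
inputs_embeds = embed(input_ids)                          # (seq_len, hidden)
image_features = mm_projector(torch.randn(num_patches, clip_dim))

masked = torch.where(input_ids == IM_PATCH_TOKEN)[0]
start = int(masked[0])
# the model raises if the placeholder run is not consecutive
assert (masked == torch.arange(start, start + num_patches)).all()

new_input_embeds = torch.cat(
    (inputs_embeds[:start], image_features, inputs_embeds[start + num_patches:]),
    dim=0,
)
print(new_input_embeds.shape)   # torch.Size([9, 8]): same length, placeholders replaced
```

The `.detach()` calls in the real forward appear to serve the adapter-pretraining case: they stop gradients from flowing into the original vocabulary embeddings so that only the newly added image tokens and the projector are updated.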
model/llava/model/llava_mpt.py CHANGED
@@ -13,24 +13,21 @@
13
  # limitations under the License.
14
 
15
 
16
- from typing import List, Optional, Tuple, Union
17
  import warnings
 
18
 
19
  import torch
20
  import torch.nn as nn
21
  import torch.nn.functional as F
22
  from torch.nn import CrossEntropyLoss
23
-
24
- import math
25
-
26
- from transformers import AutoConfig, AutoModelForCausalLM, \
27
- CLIPVisionModel, CLIPImageProcessor
28
-
29
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
30
 
31
  from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel
32
 
33
-
34
  DEFAULT_IMAGE_TOKEN = "<image>"
35
  DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
36
  DEFAULT_IM_START_TOKEN = "<im_start>"
@@ -49,19 +46,26 @@ class LlavaMPTModel(MPTModel):
49
 
50
  if hasattr(config, "mm_vision_tower"):
51
  # HACK: for FSDP
52
- self.vision_tower = [CLIPVisionModel.from_pretrained(config.mm_vision_tower)]
 
 
53
  # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower)
54
 
55
  if hasattr(config, "use_mm_proj"):
56
  self.mm_projector = nn.Linear(config.mm_hidden_size, config.d_model)
57
 
58
- def initialize_vision_modules(self, vision_tower, mm_vision_select_layer,
59
- pretrain_mm_mlp_adapter=None, tune_mm_mlp_adapter=False):
 
 
 
 
 
60
  self.config.mm_vision_tower = vision_tower
61
 
62
  image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
63
 
64
- if not hasattr(self, 'vision_tower'):
65
  vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
66
  else:
67
  vision_tower = self.vision_tower[0]
@@ -76,23 +80,44 @@ class LlavaMPTModel(MPTModel):
76
  self.config.mm_hidden_size = vision_config.hidden_size
77
  self.config.mm_vision_select_layer = mm_vision_select_layer
78
 
79
- if not hasattr(self, 'mm_projector'):
80
- self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.d_model)
 
 
81
 
82
  if pretrain_mm_mlp_adapter is not None:
83
- mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
84
- self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items() if 'mm_projector' in k})
 
 
 
 
 
 
 
 
85
 
86
  return dict(
87
  image_processor=image_processor,
88
  image_token_len=num_patches,
89
- vision_config=vision_config
90
  )
91
 
92
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None):
93
-
 
 
 
 
 
 
 
 
 
 
 
94
  # HACK: replace back original embeddings for LLaVA pretraining
95
- orig_embeds_params = getattr(self, 'orig_embeds_params', None)
96
  # if orig_embeds_params is not None:
97
  # orig_embeds_params = orig_embeds_params[0]
98
  # with torch.no_grad():
@@ -100,8 +125,12 @@ class LlavaMPTModel(MPTModel):
100
 
101
  inputs_embeds = self.wte(input_ids)
102
 
103
- vision_tower = getattr(self, 'vision_tower', None)
104
- if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
 
 
 
 
105
  # TODO: this is a modified multimodal LLM -- Haotian Liu
106
  vision_tower = vision_tower[0] # HACK: for FSDP
107
  with torch.no_grad():
@@ -109,21 +138,36 @@ class LlavaMPTModel(MPTModel):
109
  # variable length images
110
  image_features = []
111
  for image in images:
112
- image_forward_out = vision_tower(image.unsqueeze(0), output_hidden_states=True)
113
- select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1)
114
- select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
 
 
 
 
 
 
115
  image_feature = select_hidden_state[:, 1:]
116
  image_features.append(image_feature)
117
  else:
118
  image_forward_outs = vision_tower(images, output_hidden_states=True)
119
- select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1)
120
- select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
 
 
 
 
121
  image_features = select_hidden_state[:, 1:]
122
  if type(images) is list:
123
- image_features = [self.mm_projector(image_feature)[0] for image_feature in image_features]
 
 
 
124
  else:
125
  image_features = self.mm_projector(image_features)
126
- dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
 
 
127
  dummy_image_features = self.mm_projector(dummy_image_features)
128
 
129
  new_input_embeds = []
@@ -131,43 +175,130 @@ class LlavaMPTModel(MPTModel):
131
  for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
132
  if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0:
133
  # multimodal LLM, but the current sample is not multimodal
134
- cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
 
 
135
  new_input_embeds.append(cur_input_embeds)
136
  continue
137
  if vision_tower.config.use_im_start_end:
138
  cur_image_features = image_features[cur_image_idx]
139
  num_patches = cur_image_features.shape[0]
140
- if (cur_input_ids == vision_tower.config.im_start_token).sum() != (cur_input_ids == vision_tower.config.im_end_token).sum():
141
- raise ValueError("The number of image start tokens and image end tokens should be the same.")
142
- image_start_tokens = torch.where(cur_input_ids == vision_tower.config.im_start_token)[0]
 
 
 
 
 
 
143
  for image_start_token_pos in image_start_tokens:
144
- cur_image_features = image_features[cur_image_idx].to(device=cur_input_embeds.device)
 
 
145
  num_patches = cur_image_features.shape[0]
146
- if cur_input_ids[image_start_token_pos + num_patches + 1] != vision_tower.config.im_end_token:
147
- raise ValueError("The image end token should follow the image start token.")
 
 
 
 
 
148
  if orig_embeds_params is not None:
149
- cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  else:
151
- cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
 
 
 
 
 
 
 
 
 
152
  cur_image_idx += 1
153
  new_input_embeds.append(cur_new_input_embeds)
154
  else:
155
  cur_image_features = image_features[cur_image_idx]
156
  num_patches = cur_image_features.shape[0]
157
- if (cur_input_ids == vision_tower.config.im_patch_token).sum() != num_patches:
158
- raise ValueError("The number of image patch tokens should be the same as the number of image patches.")
159
- masked_indices = torch.where(cur_input_ids == vision_tower.config.im_patch_token)[0]
 
 
 
 
 
 
160
  mask_index_start = masked_indices[0]
161
- if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any():
162
- raise ValueError("The image patch tokens should be consecutive.")
 
 
 
 
 
 
 
 
 
 
163
  if orig_embeds_params is not None:
164
- cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(), cur_image_features, cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0)
 
 
 
 
 
 
 
 
 
165
  else:
166
- cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0)
 
 
 
 
 
 
 
167
  new_input_embeds.append(cur_new_input_embeds)
168
  inputs_embeds = torch.stack(new_input_embeds, dim=0)
169
 
170
- return super(LlavaMPTModel, self).forward(input_ids=None, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, tok_emb=inputs_embeds)
 
 
 
 
 
 
 
 
 
 
 
171
 
172
 
173
  class LlavaMPTForCausalLM(MPTForCausalLM):
@@ -178,16 +309,18 @@ class LlavaMPTForCausalLM(MPTForCausalLM):
178
  super(MPTForCausalLM, self).__init__(config)
179
 
180
  if not config.tie_word_embeddings:
181
- raise ValueError('MPTForCausalLM only supports tied word embeddings')
182
  self.transformer = LlavaMPTModel(config)
183
  self.logit_scale = None
184
  if config.logit_scale is not None:
185
  logit_scale = config.logit_scale
186
  if isinstance(logit_scale, str):
187
- if logit_scale == 'inv_sqrt_d_model':
188
  logit_scale = 1 / math.sqrt(config.d_model)
189
  else:
190
- raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
 
 
191
  self.logit_scale = logit_scale
192
 
193
  def get_model(self):
@@ -197,28 +330,67 @@ class LlavaMPTForCausalLM(MPTForCausalLM):
197
  if isinstance(module, LlavaMPTModel):
198
  module.gradient_checkpointing = value
199
 
200
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None):
201
- return_dict = return_dict if return_dict is not None else self.config.return_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  use_cache = use_cache if use_cache is not None else self.config.use_cache
203
- outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, images=images)
 
 
 
 
 
 
 
 
 
 
 
204
  logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
205
  if self.logit_scale is not None:
206
  if self.logit_scale == 0:
207
- warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
 
 
208
  logits *= self.logit_scale
209
  loss = None
210
  if labels is not None:
211
  labels = torch.roll(labels, shifts=-1)
212
  labels[:, -1] = -100
213
- loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
214
- return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
 
 
 
 
 
 
 
215
 
216
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
 
 
217
  if inputs_embeds is not None:
218
- raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
219
- attention_mask = kwargs['attention_mask'].bool()
220
  if attention_mask[:, -1].sum() != attention_mask.shape[0]:
221
- raise NotImplementedError('MPT does not support generation with right padding.')
 
 
222
  if self.transformer.attn_uses_sequence_id and self.training:
223
  sequence_id = torch.zeros_like(input_ids[:1])
224
  else:
@@ -227,55 +399,91 @@ class LlavaMPTForCausalLM(MPTForCausalLM):
227
  input_ids = input_ids[:, -1].unsqueeze(-1)
228
  if self.transformer.prefix_lm:
229
  prefix_mask = torch.ones_like(attention_mask)
230
- if kwargs.get('use_cache') == False:
231
- raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
 
 
232
  else:
233
  prefix_mask = None
234
- return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)}
235
-
236
- def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device,
237
- tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  vision_config = self.get_model().vision_tower[0].config
239
  vision_config.use_im_start_end = mm_use_im_start_end
240
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
241
  self.resize_token_embeddings(len(tokenizer))
242
 
243
  if mm_use_im_start_end:
244
- num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
 
 
245
  self.resize_token_embeddings(len(tokenizer))
246
- vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
 
 
 
 
 
247
 
248
  if num_new_tokens > 0:
249
  input_embeddings = self.get_input_embeddings().weight.data
250
  output_embeddings = self.get_output_embeddings().weight.data
251
 
252
  input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
253
- dim=0, keepdim=True)
 
254
  output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
255
- dim=0, keepdim=True)
 
256
 
257
  input_embeddings[-num_new_tokens:] = input_embeddings_avg
258
  output_embeddings[-num_new_tokens:] = output_embeddings_avg
259
 
260
  if tune_mm_mlp_adapter:
261
- self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
 
 
262
  for p in self.get_input_embeddings().parameters():
263
  p.requires_grad = True
264
  for p in self.get_output_embeddings().parameters():
265
  p.requires_grad = False
266
 
267
  if pretrain_mm_mlp_adapter:
268
- mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
269
- embed_tokens_weight = mm_projector_weights['transformer.wte.weight']
 
 
270
  assert num_new_tokens == 2
271
  if input_embeddings.shape == embed_tokens_weight.shape:
272
- input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
 
 
273
  elif embed_tokens_weight.shape[0] == num_new_tokens:
274
  input_embeddings[-num_new_tokens:] = embed_tokens_weight
275
  else:
276
- raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
 
 
 
 
 
 
277
 
278
- vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
279
 
280
  AutoConfig.register("llava_mpt", LlavaMPTConfig)
281
  AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM)
 
13
  # limitations under the License.
14
 
15
 
16
+ import math
17
  import warnings
18
+ from typing import List, Optional, Tuple, Union
19
 
20
  import torch
21
  import torch.nn as nn
22
  import torch.nn.functional as F
23
  from torch.nn import CrossEntropyLoss
24
+ from transformers import (AutoConfig, AutoModelForCausalLM, CLIPImageProcessor,
25
+ CLIPVisionModel)
26
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
27
+ CausalLMOutputWithPast)
 
 
 
28
 
29
  from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel
30
 
 
31
  DEFAULT_IMAGE_TOKEN = "<image>"
32
  DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
33
  DEFAULT_IM_START_TOKEN = "<im_start>"
 
46
 
47
  if hasattr(config, "mm_vision_tower"):
48
  # HACK: for FSDP
49
+ self.vision_tower = [
50
+ CLIPVisionModel.from_pretrained(config.mm_vision_tower)
51
+ ]
52
  # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower)
53
 
54
  if hasattr(config, "use_mm_proj"):
55
  self.mm_projector = nn.Linear(config.mm_hidden_size, config.d_model)
56
 
57
+ def initialize_vision_modules(
58
+ self,
59
+ vision_tower,
60
+ mm_vision_select_layer,
61
+ pretrain_mm_mlp_adapter=None,
62
+ tune_mm_mlp_adapter=False,
63
+ ):
64
  self.config.mm_vision_tower = vision_tower
65
 
66
  image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
67
 
68
+ if not hasattr(self, "vision_tower"):
69
  vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
70
  else:
71
  vision_tower = self.vision_tower[0]
 
80
  self.config.mm_hidden_size = vision_config.hidden_size
81
  self.config.mm_vision_select_layer = mm_vision_select_layer
82
 
83
+ if not hasattr(self, "mm_projector"):
84
+ self.mm_projector = nn.Linear(
85
+ vision_config.hidden_size, self.config.d_model
86
+ )
87
 
88
  if pretrain_mm_mlp_adapter is not None:
89
+ mm_projector_weights = torch.load(
90
+ pretrain_mm_mlp_adapter, map_location="cpu"
91
+ )
92
+ self.mm_projector.load_state_dict(
93
+ {
94
+ k.split(".")[-1]: v
95
+ for k, v in mm_projector_weights.items()
96
+ if "mm_projector" in k
97
+ }
98
+ )
99
 
100
  return dict(
101
  image_processor=image_processor,
102
  image_token_len=num_patches,
103
+ vision_config=vision_config,
104
  )
105
 
106
+ def forward(
107
+ self,
108
+ input_ids: torch.LongTensor,
109
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
110
+ attention_mask: Optional[torch.ByteTensor] = None,
111
+ prefix_mask: Optional[torch.ByteTensor] = None,
112
+ sequence_id: Optional[torch.LongTensor] = None,
113
+ return_dict: Optional[bool] = None,
114
+ output_attentions: Optional[bool] = None,
115
+ output_hidden_states: Optional[bool] = None,
116
+ use_cache: Optional[bool] = None,
117
+ images=None,
118
+ ):
119
  # HACK: replace back original embeddings for LLaVA pretraining
120
+ orig_embeds_params = getattr(self, "orig_embeds_params", None)
121
  # if orig_embeds_params is not None:
122
  # orig_embeds_params = orig_embeds_params[0]
123
  # with torch.no_grad():
 
125
 
126
  inputs_embeds = self.wte(input_ids)
127
 
128
+ vision_tower = getattr(self, "vision_tower", None)
129
+ if (
130
+ vision_tower is not None
131
+ and (input_ids.shape[1] != 1 or self.training)
132
+ and images is not None
133
+ ):
134
  # TODO: this is a modified multimodal LLM -- Haotian Liu
135
  vision_tower = vision_tower[0] # HACK: for FSDP
136
  with torch.no_grad():
 
138
  # variable length images
139
  image_features = []
140
  for image in images:
141
+ image_forward_out = vision_tower(
142
+ image.unsqueeze(0), output_hidden_states=True
143
+ )
144
+ select_hidden_state_layer = getattr(
145
+ self.config, "mm_vision_select_layer", -1
146
+ )
147
+ select_hidden_state = image_forward_out.hidden_states[
148
+ select_hidden_state_layer
149
+ ]
150
  image_feature = select_hidden_state[:, 1:]
151
  image_features.append(image_feature)
152
  else:
153
  image_forward_outs = vision_tower(images, output_hidden_states=True)
154
+ select_hidden_state_layer = getattr(
155
+ self.config, "mm_vision_select_layer", -1
156
+ )
157
+ select_hidden_state = image_forward_outs.hidden_states[
158
+ select_hidden_state_layer
159
+ ]
160
  image_features = select_hidden_state[:, 1:]
161
  if type(images) is list:
162
+ image_features = [
163
+ self.mm_projector(image_feature)[0]
164
+ for image_feature in image_features
165
+ ]
166
  else:
167
  image_features = self.mm_projector(image_features)
168
+ dummy_image_features = torch.zeros(
169
+ 256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype
170
+ )
171
  dummy_image_features = self.mm_projector(dummy_image_features)
172
 
173
  new_input_embeds = []
 
175
  for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
176
  if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0:
177
  # multimodal LLM, but the current sample is not multimodal
178
+ cur_input_embeds = (
179
+ cur_input_embeds + (0.0 * dummy_image_features).sum()
180
+ )
181
  new_input_embeds.append(cur_input_embeds)
182
  continue
183
  if vision_tower.config.use_im_start_end:
184
  cur_image_features = image_features[cur_image_idx]
185
  num_patches = cur_image_features.shape[0]
186
+ if (cur_input_ids == vision_tower.config.im_start_token).sum() != (
187
+ cur_input_ids == vision_tower.config.im_end_token
188
+ ).sum():
189
+ raise ValueError(
190
+ "The number of image start tokens and image end tokens should be the same."
191
+ )
192
+ image_start_tokens = torch.where(
193
+ cur_input_ids == vision_tower.config.im_start_token
194
+ )[0]
195
  for image_start_token_pos in image_start_tokens:
196
+ cur_image_features = image_features[cur_image_idx].to(
197
+ device=cur_input_embeds.device
198
+ )
199
  num_patches = cur_image_features.shape[0]
200
+ if (
201
+ cur_input_ids[image_start_token_pos + num_patches + 1]
202
+ != vision_tower.config.im_end_token
203
+ ):
204
+ raise ValueError(
205
+ "The image end token should follow the image start token."
206
+ )
207
  if orig_embeds_params is not None:
208
+ cur_new_input_embeds = torch.cat(
209
+ (
210
+ cur_input_embeds[:image_start_token_pos].detach(),
211
+ cur_input_embeds[
212
+ image_start_token_pos : image_start_token_pos
213
+ + 1
214
+ ],
215
+ cur_image_features,
216
+ cur_input_embeds[
217
+ image_start_token_pos
218
+ + num_patches
219
+ + 1 : image_start_token_pos
220
+ + num_patches
221
+ + 2
222
+ ],
223
+ cur_input_embeds[
224
+ image_start_token_pos + num_patches + 2 :
225
+ ].detach(),
226
+ ),
227
+ dim=0,
228
+ )
229
  else:
230
+ cur_new_input_embeds = torch.cat(
231
+ (
232
+ cur_input_embeds[: image_start_token_pos + 1],
233
+ cur_image_features,
234
+ cur_input_embeds[
235
+ image_start_token_pos + num_patches + 1 :
236
+ ],
237
+ ),
238
+ dim=0,
239
+ )
240
  cur_image_idx += 1
241
  new_input_embeds.append(cur_new_input_embeds)
242
  else:
243
  cur_image_features = image_features[cur_image_idx]
244
  num_patches = cur_image_features.shape[0]
245
+ if (
246
+ cur_input_ids == vision_tower.config.im_patch_token
247
+ ).sum() != num_patches:
248
+ raise ValueError(
249
+ "The number of image patch tokens should be the same as the number of image patches."
250
+ )
251
+ masked_indices = torch.where(
252
+ cur_input_ids == vision_tower.config.im_patch_token
253
+ )[0]
254
  mask_index_start = masked_indices[0]
255
+ if (
256
+ masked_indices
257
+ != torch.arange(
258
+ mask_index_start,
259
+ mask_index_start + num_patches,
260
+ device=masked_indices.device,
261
+ dtype=masked_indices.dtype,
262
+ )
263
+ ).any():
264
+ raise ValueError(
265
+ "The image patch tokens should be consecutive."
266
+ )
267
  if orig_embeds_params is not None:
268
+ cur_new_input_embeds = torch.cat(
269
+ (
270
+ cur_input_embeds[:mask_index_start].detach(),
271
+ cur_image_features,
272
+ cur_input_embeds[
273
+ mask_index_start + num_patches :
274
+ ].detach(),
275
+ ),
276
+ dim=0,
277
+ )
278
  else:
279
+ cur_new_input_embeds = torch.cat(
280
+ (
281
+ cur_input_embeds[:mask_index_start],
282
+ cur_image_features,
283
+ cur_input_embeds[mask_index_start + num_patches :],
284
+ ),
285
+ dim=0,
286
+ )
287
  new_input_embeds.append(cur_new_input_embeds)
288
  inputs_embeds = torch.stack(new_input_embeds, dim=0)
289
 
290
+ return super(LlavaMPTModel, self).forward(
291
+ input_ids=None,
292
+ past_key_values=past_key_values,
293
+ attention_mask=attention_mask,
294
+ prefix_mask=prefix_mask,
295
+ sequence_id=sequence_id,
296
+ return_dict=return_dict,
297
+ output_attentions=output_attentions,
298
+ output_hidden_states=output_hidden_states,
299
+ use_cache=use_cache,
300
+ tok_emb=inputs_embeds,
301
+ )
302
 
303
 
304
  class LlavaMPTForCausalLM(MPTForCausalLM):
 
309
  super(MPTForCausalLM, self).__init__(config)
310
 
311
  if not config.tie_word_embeddings:
312
+ raise ValueError("MPTForCausalLM only supports tied word embeddings")
313
  self.transformer = LlavaMPTModel(config)
314
  self.logit_scale = None
315
  if config.logit_scale is not None:
316
  logit_scale = config.logit_scale
317
  if isinstance(logit_scale, str):
318
+ if logit_scale == "inv_sqrt_d_model":
319
  logit_scale = 1 / math.sqrt(config.d_model)
320
  else:
321
+ raise ValueError(
322
+ f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
323
+ )
324
  self.logit_scale = logit_scale
325
 
326
  def get_model(self):
 
330
  if isinstance(module, LlavaMPTModel):
331
  module.gradient_checkpointing = value
332
 
333
+ def forward(
334
+ self,
335
+ input_ids: torch.LongTensor,
336
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
337
+ attention_mask: Optional[torch.ByteTensor] = None,
338
+ prefix_mask: Optional[torch.ByteTensor] = None,
339
+ sequence_id: Optional[torch.LongTensor] = None,
340
+ labels: Optional[torch.LongTensor] = None,
341
+ return_dict: Optional[bool] = None,
342
+ output_attentions: Optional[bool] = None,
343
+ output_hidden_states: Optional[bool] = None,
344
+ use_cache: Optional[bool] = None,
345
+ images=None,
346
+ ):
347
+ return_dict = (
348
+ return_dict if return_dict is not None else self.config.return_dict
349
+ )
350
  use_cache = use_cache if use_cache is not None else self.config.use_cache
351
+ outputs = self.transformer(
352
+ input_ids=input_ids,
353
+ past_key_values=past_key_values,
354
+ attention_mask=attention_mask,
355
+ prefix_mask=prefix_mask,
356
+ sequence_id=sequence_id,
357
+ return_dict=return_dict,
358
+ output_attentions=output_attentions,
359
+ output_hidden_states=output_hidden_states,
360
+ use_cache=use_cache,
361
+ images=images,
362
+ )
363
  logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
364
  if self.logit_scale is not None:
365
  if self.logit_scale == 0:
366
+ warnings.warn(
367
+ f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs."
368
+ )
369
  logits *= self.logit_scale
370
  loss = None
371
  if labels is not None:
372
  labels = torch.roll(labels, shifts=-1)
373
  labels[:, -1] = -100
374
+ loss = F.cross_entropy(
375
+ logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
376
+ )
377
+ return CausalLMOutputWithPast(
378
+ loss=loss,
379
+ logits=logits,
380
+ past_key_values=outputs.past_key_values,
381
+ hidden_states=outputs.hidden_states,
382
+ )
383
 
384
+ def prepare_inputs_for_generation(
385
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
386
+ ):
387
  if inputs_embeds is not None:
388
+ raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
389
+ attention_mask = kwargs["attention_mask"].bool()
390
  if attention_mask[:, -1].sum() != attention_mask.shape[0]:
391
+ raise NotImplementedError(
392
+ "MPT does not support generation with right padding."
393
+ )
394
  if self.transformer.attn_uses_sequence_id and self.training:
395
  sequence_id = torch.zeros_like(input_ids[:1])
396
  else:
 
399
  input_ids = input_ids[:, -1].unsqueeze(-1)
400
  if self.transformer.prefix_lm:
401
  prefix_mask = torch.ones_like(attention_mask)
402
+ if kwargs.get("use_cache") == False:
403
+ raise NotImplementedError(
404
+ "MPT with prefix_lm=True does not support use_cache=False."
405
+ )
406
  else:
407
  prefix_mask = None
408
+ return {
409
+ "input_ids": input_ids,
410
+ "attention_mask": attention_mask,
411
+ "prefix_mask": prefix_mask,
412
+ "sequence_id": sequence_id,
413
+ "past_key_values": past_key_values,
414
+ "use_cache": kwargs.get("use_cache", True),
415
+ "images": kwargs.get("images", None),
416
+ }
417
+
418
+ def initialize_vision_tokenizer(
419
+ self,
420
+ mm_use_im_start_end,
421
+ tokenizer,
422
+ device,
423
+ tune_mm_mlp_adapter=False,
424
+ pretrain_mm_mlp_adapter=None,
425
+ ):
426
  vision_config = self.get_model().vision_tower[0].config
427
  vision_config.use_im_start_end = mm_use_im_start_end
428
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
429
  self.resize_token_embeddings(len(tokenizer))
430
 
431
  if mm_use_im_start_end:
432
+ num_new_tokens = tokenizer.add_tokens(
433
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
434
+ )
435
  self.resize_token_embeddings(len(tokenizer))
436
+ (
437
+ vision_config.im_start_token,
438
+ vision_config.im_end_token,
439
+ ) = tokenizer.convert_tokens_to_ids(
440
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
441
+ )
442
 
443
  if num_new_tokens > 0:
444
  input_embeddings = self.get_input_embeddings().weight.data
445
  output_embeddings = self.get_output_embeddings().weight.data
446
 
447
  input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
448
+ dim=0, keepdim=True
449
+ )
450
  output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
451
+ dim=0, keepdim=True
452
+ )
453
 
454
  input_embeddings[-num_new_tokens:] = input_embeddings_avg
455
  output_embeddings[-num_new_tokens:] = output_embeddings_avg
456
 
457
  if tune_mm_mlp_adapter:
458
+ self.get_model().orig_embeds_params = [
459
+ self.get_input_embeddings().weight.data.clone().to(device=device)
460
+ ]
461
  for p in self.get_input_embeddings().parameters():
462
  p.requires_grad = True
463
  for p in self.get_output_embeddings().parameters():
464
  p.requires_grad = False
465
 
466
  if pretrain_mm_mlp_adapter:
467
+ mm_projector_weights = torch.load(
468
+ pretrain_mm_mlp_adapter, map_location="cpu"
469
+ )
470
+ embed_tokens_weight = mm_projector_weights["transformer.wte.weight"]
471
  assert num_new_tokens == 2
472
  if input_embeddings.shape == embed_tokens_weight.shape:
473
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[
474
+ -num_new_tokens:
475
+ ]
476
  elif embed_tokens_weight.shape[0] == num_new_tokens:
477
  input_embeddings[-num_new_tokens:] = embed_tokens_weight
478
  else:
479
+ raise ValueError(
480
+ f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
481
+ )
482
+
483
+ vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
484
+ [DEFAULT_IMAGE_PATCH_TOKEN]
485
+ )[0]
486
 
 
487
 
488
  AutoConfig.register("llava_mpt", LlavaMPTConfig)
489
  AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM)
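
One detail in the `LlavaMPTForCausalLM.forward` shown above is easy to miss: the loss shifts the labels with `torch.roll` and masks the last position with -100 rather than slicing logits and labels. A small sketch with toy tensors, showing that this is equivalent to the conventional next-token shift:

```
# Sketch of the label handling in LlavaMPTForCausalLM.forward: roll labels
# left by one and mask the final slot with -100 (cross_entropy's default
# ignore_index), which matches the usual next-token objective.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# roll-and-mask variant (the repo rolls the flattened tensor; the element that
# wraps across rows lands in the last column, which is masked anyway)
rolled = torch.roll(labels, shifts=-1)
rolled[:, -1] = -100
loss_rolled = F.cross_entropy(logits.view(-1, vocab), rolled.view(-1))

# conventional explicit shift, for comparison
loss_shift = F.cross_entropy(
    logits[:, :-1, :].reshape(-1, vocab), labels[:, 1:].reshape(-1)
)

print(torch.allclose(loss_rolled, loss_shift))   # True
```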
model/llava/model/make_delta.py CHANGED
@@ -5,31 +5,40 @@ python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/mod
5
  import argparse
6
 
7
  import torch
8
- from tqdm import tqdm
9
- from transformers import AutoTokenizer, AutoModelForCausalLM
10
  from llava.model.utils import auto_upgrade
 
 
11
 
12
 
13
  def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14
  print("Loading base model")
15
  base = AutoModelForCausalLM.from_pretrained(
16
- base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
 
17
 
18
  print("Loading target model")
19
  auto_upgrade(target_model_path)
20
- target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
 
 
21
 
22
  print("Calculating delta")
23
  for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24
  if name not in base.state_dict():
25
- assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
 
 
 
26
  continue
27
  if param.data.shape == base.state_dict()[name].shape:
28
  param.data -= base.state_dict()[name]
29
  else:
30
- assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
 
 
 
31
  bparam = base.state_dict()[name]
32
- param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33
 
34
  print("Saving delta")
35
  if hub_repo_id:
@@ -49,4 +58,6 @@ if __name__ == "__main__":
49
  parser.add_argument("--hub-repo-id", type=str, default=None)
50
  args = parser.parse_args()
51
 
52
- make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
 
 
 
5
  import argparse
6
 
7
  import torch
 
 
8
  from llava.model.utils import auto_upgrade
9
+ from tqdm import tqdm
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
 
12
 
13
  def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14
  print("Loading base model")
15
  base = AutoModelForCausalLM.from_pretrained(
16
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
17
+ )
18
 
19
  print("Loading target model")
20
  auto_upgrade(target_model_path)
21
+ target = AutoModelForCausalLM.from_pretrained(
22
+ target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
23
+ )
24
 
25
  print("Calculating delta")
26
  for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
27
  if name not in base.state_dict():
28
+ assert name in [
29
+ "model.mm_projector.weight",
30
+ "model.mm_projector.bias",
31
+ ], f"{name} not in base model"
32
  continue
33
  if param.data.shape == base.state_dict()[name].shape:
34
  param.data -= base.state_dict()[name]
35
  else:
36
+ assert name in [
37
+ "model.embed_tokens.weight",
38
+ "lm_head.weight",
39
+ ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
40
  bparam = base.state_dict()[name]
41
+ param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam
42
 
43
  print("Saving delta")
44
  if hub_repo_id:
 
58
  parser.add_argument("--hub-repo-id", type=str, default=None)
59
  args = parser.parse_args()
60
 
61
+ make_delta(
62
+ args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id
63
+ )
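
For readers unfamiliar with the delta-weight release format that `make_delta.py` implements: deltas are element-wise differences against the base LLaMA weights, with a special case for tensors that grew when new tokens were added. A toy round-trip under those assumptions (the apply direction is the presumed inverse performed by the companion apply script, not code from this commit):

```
# Toy round-trip for the delta scheme in make_delta.py: for matching shapes the
# delta is target minus base; for a resized embedding only the overlapping
# block is differenced, so rows for newly added tokens are stored as-is.
import torch

base_embed = torch.randn(10, 4)      # pretend base vocab x hidden
target_embed = torch.randn(12, 4)    # vocab grown by 2 tokens in the target

# make_delta direction
delta = target_embed.clone()
delta[: base_embed.shape[0], : base_embed.shape[1]] -= base_embed

# presumed apply direction: adding the base back recovers the target
recovered = delta.clone()
recovered[: base_embed.shape[0], : base_embed.shape[1]] += base_embed

print(torch.allclose(recovered, target_embed))   # True
```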
model/llava/model/mpt/adapt_tokenizer.py CHANGED
@@ -1,8 +1,12 @@
1
  from typing import Union
2
- from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
 
 
 
3
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
4
  NUM_SENTINEL_TOKENS: int = 100
5
 
 
6
  def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
7
  """Adds sentinel tokens and padding token (if missing).
8
 
@@ -12,16 +16,17 @@ def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
12
  All added tokens are added as special tokens. No tokens are
13
  added if sentinel tokens and padding token already exist.
14
  """
15
- sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]
16
  tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
17
  if tokenizer.pad_token is None:
18
- tokenizer.add_tokens('<pad>', special_tokens=True)
19
- tokenizer.pad_token = '<pad>'
20
  assert tokenizer.pad_token_id is not None
21
- sentinels = ''.join([f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)])
22
  _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
23
  tokenizer.sentinel_token_ids = _sentinel_token_ids
24
 
 
25
  class AutoTokenizerForMOD(AutoTokenizer):
26
  """AutoTokenizer + Adaptation for MOD.
27
 
@@ -38,4 +43,4 @@ class AutoTokenizerForMOD(AutoTokenizer):
38
  """See `AutoTokenizer.from_pretrained` docstring."""
39
  tokenizer = super().from_pretrained(*args, **kwargs)
40
  adapt_tokenizer_for_denoising(tokenizer)
41
- return tokenizer
 
1
  from typing import Union
2
+
3
+ from transformers import (AutoTokenizer, PreTrainedTokenizer,
4
+ PreTrainedTokenizerFast)
5
+
6
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
7
  NUM_SENTINEL_TOKENS: int = 100
8
 
9
+
10
  def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
11
  """Adds sentinel tokens and padding token (if missing).
12
 
 
16
  All added tokens are added as special tokens. No tokens are
17
  added if sentinel tokens and padding token already exist.
18
  """
19
+ sentinels_to_add = [f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)]
20
  tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
21
  if tokenizer.pad_token is None:
22
+ tokenizer.add_tokens("<pad>", special_tokens=True)
23
+ tokenizer.pad_token = "<pad>"
24
  assert tokenizer.pad_token_id is not None
25
+ sentinels = "".join([f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)])
26
  _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
27
  tokenizer.sentinel_token_ids = _sentinel_token_ids
28
 
29
+
30
  class AutoTokenizerForMOD(AutoTokenizer):
31
  """AutoTokenizer + Adaptation for MOD.
32
 
 
43
  """See `AutoTokenizer.from_pretrained` docstring."""
44
  tokenizer = super().from_pretrained(*args, **kwargs)
45
  adapt_tokenizer_for_denoising(tokenizer)
46
+ return tokenizer
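
The `adapt_tokenizer.py` change above is formatting only; for context, this is roughly how the adapted tokenizer behaves. The sketch below uses the `gpt2` tokenizer purely as a stand-in and assumes it is available locally; it is not the tokenizer this repo ships with:

```
# Usage sketch for adapt_tokenizer_for_denoising: it appends 100 sentinel
# tokens (<extra_id_0> .. <extra_id_99>) plus a <pad> token if one is missing,
# and records the sentinel ids on the tokenizer.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")   # stand-in tokenizer, see note above
vocab_before = len(tok)

sentinels_to_add = [f"<extra_id_{i}>" for i in range(100)]
tok.add_tokens(sentinels_to_add, special_tokens=True)
if tok.pad_token is None:
    tok.add_tokens("<pad>", special_tokens=True)
    tok.pad_token = "<pad>"

print(len(tok) - vocab_before)       # 101 when the base tokenizer had no pad token
sentinels = "".join(sentinels_to_add)
print(len(tok(sentinels, add_special_tokens=False).input_ids))   # 100
```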
model/llava/model/mpt/attention.py CHANGED
@@ -2,24 +2,45 @@
2
  import math
3
  import warnings
4
  from typing import Optional
 
5
  import torch
6
  import torch.nn as nn
7
  from einops import rearrange
8
  from torch import nn
 
9
  from .norm import LPLayerNorm
10
 
11
- def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool):
 
 
 
12
  if original_is_causal and num_query_tokens != num_key_tokens:
13
  if num_query_tokens != 1:
14
- raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.')
 
 
15
  else:
16
  return False
17
  return original_is_causal
18
 
19
- def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
20
- q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
21
- k = rearrange(key, 'b s (h d) -> b h d s', h=1 if multiquery else n_heads)
22
- v = rearrange(value, 'b s (h d) -> b h s d', h=1 if multiquery else n_heads)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  min_val = torch.finfo(q.dtype).min
24
  (b, _, s_q, d) = q.shape
25
  s_k = k.size(-1)
@@ -27,13 +48,27 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_s
27
  softmax_scale = 1 / math.sqrt(d)
28
  attn_weight = q.matmul(k) * softmax_scale
29
  if attn_bias is not None:
30
- if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
31
- raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
 
 
 
 
 
 
32
  attn_weight = attn_weight + attn_bias
33
  if key_padding_mask is not None:
34
  if attn_bias is not None:
35
- warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
36
- attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
 
 
 
 
 
 
 
 
37
  if is_causal:
38
  s = max(s_q, s_k)
39
  causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
@@ -44,74 +79,146 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_s
44
  attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
45
  attn_weight = torch.softmax(attn_weight, dim=-1)
46
  if dropout_p:
47
- attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
 
 
48
  out = attn_weight.matmul(v)
49
- out = rearrange(out, 'b h s d -> b s (h d)')
50
  if needs_weights:
51
  return (out, attn_weight)
52
  return (out, None)
53
 
 
54
  def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
55
  for tensor in tensors:
56
  if tensor.dtype not in valid_dtypes:
57
- raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.')
 
 
58
  if not tensor.is_cuda:
59
- raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
 
 
 
60
 
61
- def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  try:
63
  from flash_attn import bert_padding, flash_attn_interface
64
  except:
65
- raise RuntimeError('Please install flash-attn==1.0.3.post0')
66
  check_valid_inputs(query, key, value)
67
  if attn_bias is not None:
68
- raise NotImplementedError(f'attn_bias not implemented for flash attn.')
69
  (batch_size, seqlen) = query.shape[:2]
70
  if key_padding_mask is None:
71
  key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
72
- query_padding_mask = key_padding_mask[:, -query.size(1):]
73
- (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)
74
- query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
75
- (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)
76
- key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
 
 
 
 
 
 
77
  (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
78
- value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
 
 
79
  if multiquery:
80
  key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
81
- value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
 
 
82
  dropout_p = dropout_p if training else 0.0
83
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
84
- output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
85
- output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  return (output, None)
87
 
88
- def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  try:
90
  from flash_attn import flash_attn_triton
91
  except:
92
- raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')
 
 
93
  check_valid_inputs(query, key, value)
94
  if dropout_p:
95
- raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
96
  if needs_weights:
97
- raise NotImplementedError(f'attn_impl: triton cannot return attn weights.')
98
  if key_padding_mask is not None:
99
- warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
 
 
 
 
 
 
100
  (b_size, s_k) = key_padding_mask.shape[:2]
101
  if attn_bias is None:
102
  attn_bias = query.new_zeros(b_size, 1, 1, s_k)
103
- attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
104
- query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
105
- key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
106
- value = rearrange(value, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
 
 
107
  if multiquery:
108
  key = key.expand(*key.shape[:2], n_heads, key.size(-1))
109
  value = value.expand(*value.shape[:2], n_heads, value.size(-1))
110
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
111
- attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
@@ -119,7 +226,18 @@ class MultiheadAttention(nn.Module):
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
@@ -137,21 +255,38 @@ class MultiheadAttention(nn.Module):
- if self.attn_impl == 'flash':
- elif self.attn_impl == 'triton':
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
- elif self.attn_impl == 'torch':
- warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
- raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
- def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
@@ -167,10 +302,23 @@ class MultiheadAttention(nn.Module):
- attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
- (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
@@ -178,7 +326,18 @@ class MultiQueryAttention(nn.Module):
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
@@ -197,25 +356,44 @@ class MultiQueryAttention(nn.Module):
- if self.attn_impl == 'flash':
- elif self.attn_impl == 'triton':
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
- elif self.attn_impl == 'torch':
- warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
- raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
- def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
- (query, key, value) = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2)
@@ -227,14 +405,30 @@ class MultiQueryAttention(nn.Module):
- attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
- (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
- def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
- if attn_impl == 'flash':
- elif attn_impl in ['torch', 'triton']:
@@ -243,18 +437,31 @@ def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_s
- raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
- def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8):
- if attn_impl == 'flash':
- elif attn_impl in ['torch', 'triton']:
- attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
- raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
@@ -265,12 +472,24 @@ def gen_slopes(n_heads, alibi_bias_max=8, device=None):
- def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None):
- alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len)
- alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1)
- ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
2
  import math
3
  import warnings
4
  from typing import Optional
5
+
6
  import torch
7
  import torch.nn as nn
8
  from einops import rearrange
9
  from torch import nn
10
+
11
  from .norm import LPLayerNorm
12
 
13
+
14
+ def _reset_is_causal(
15
+ num_query_tokens: int, num_key_tokens: int, original_is_causal: bool
16
+ ):
17
  if original_is_causal and num_query_tokens != num_key_tokens:
18
  if num_query_tokens != 1:
19
+ raise NotImplementedError(
20
+ "MPT does not support query and key with different number of tokens, unless number of query tokens is 1."
21
+ )
22
  else:
23
  return False
24
  return original_is_causal
25
 
26
+
27
+ def scaled_multihead_dot_product_attention(
28
+ query,
29
+ key,
30
+ value,
31
+ n_heads,
32
+ softmax_scale=None,
33
+ attn_bias=None,
34
+ key_padding_mask=None,
35
+ is_causal=False,
36
+ dropout_p=0.0,
37
+ training=False,
38
+ needs_weights=False,
39
+ multiquery=False,
40
+ ):
41
+ q = rearrange(query, "b s (h d) -> b h s d", h=n_heads)
42
+ k = rearrange(key, "b s (h d) -> b h d s", h=1 if multiquery else n_heads)
43
+ v = rearrange(value, "b s (h d) -> b h s d", h=1 if multiquery else n_heads)
44
  min_val = torch.finfo(q.dtype).min
45
  (b, _, s_q, d) = q.shape
46
  s_k = k.size(-1)
 
48
  softmax_scale = 1 / math.sqrt(d)
49
  attn_weight = q.matmul(k) * softmax_scale
50
  if attn_bias is not None:
51
+ if (
52
+ attn_bias.size(-1) != 1
53
+ and attn_bias.size(-1) != s_k
54
+ or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q)
55
+ ):
56
+ raise RuntimeError(
57
+ f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}."
58
+ )
59
  attn_weight = attn_weight + attn_bias
60
  if key_padding_mask is not None:
61
  if attn_bias is not None:
62
+ warnings.warn(
63
+ "Propogating key_padding_mask to the attention module "
64
+ + "and applying it within the attention module can cause "
65
+ + "unneccessary computation/memory usage. Consider integrating "
66
+ + "into attn_bias once and passing that to each attention "
67
+ + "module instead."
68
+ )
69
+ attn_weight = attn_weight.masked_fill(
70
+ ~key_padding_mask.view((b, 1, 1, s_k)), min_val
71
+ )
72
  if is_causal:
73
  s = max(s_q, s_k)
74
  causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
 
79
  attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
80
  attn_weight = torch.softmax(attn_weight, dim=-1)
81
  if dropout_p:
82
+ attn_weight = torch.nn.functional.dropout(
83
+ attn_weight, p=dropout_p, training=training, inplace=True
84
+ )
85
  out = attn_weight.matmul(v)
86
+ out = rearrange(out, "b h s d -> b s (h d)")
87
  if needs_weights:
88
  return (out, attn_weight)
89
  return (out, None)
90
 
91
+
92
  def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
93
  for tensor in tensors:
94
  if tensor.dtype not in valid_dtypes:
95
+ raise TypeError(
96
+ f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}."
97
+ )
98
  if not tensor.is_cuda:
99
+ raise TypeError(
100
+ f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})."
101
+ )
102
+
103
 
104
+ def flash_attn_fn(
105
+ query,
106
+ key,
107
+ value,
108
+ n_heads,
109
+ softmax_scale=None,
110
+ attn_bias=None,
111
+ key_padding_mask=None,
112
+ is_causal=False,
113
+ dropout_p=0.0,
114
+ training=False,
115
+ needs_weights=False,
116
+ multiquery=False,
117
+ ):
118
  try:
119
  from flash_attn import bert_padding, flash_attn_interface
120
  except:
121
+ raise RuntimeError("Please install flash-attn==1.0.3.post0")
122
  check_valid_inputs(query, key, value)
123
  if attn_bias is not None:
124
+ raise NotImplementedError(f"attn_bias not implemented for flash attn.")
125
  (batch_size, seqlen) = query.shape[:2]
126
  if key_padding_mask is None:
127
  key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
128
+ query_padding_mask = key_padding_mask[:, -query.size(1) :]
129
+ (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
130
+ query, query_padding_mask
131
+ )
132
+ query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads)
133
+ (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
134
+ key, key_padding_mask
135
+ )
136
+ key_unpad = rearrange(
137
+ key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads
138
+ )
139
  (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
140
+ value_unpad = rearrange(
141
+ value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads
142
+ )
143
  if multiquery:
144
  key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
145
+ value_unpad = value_unpad.expand(
146
+ value_unpad.size(0), n_heads, value_unpad.size(-1)
147
+ )
148
  dropout_p = dropout_p if training else 0.0
149
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
150
+ output_unpad = flash_attn_interface.flash_attn_unpadded_func(
151
+ query_unpad,
152
+ key_unpad,
153
+ value_unpad,
154
+ cu_seqlens_q,
155
+ cu_seqlens_k,
156
+ max_seqlen_q,
157
+ max_seqlen_k,
158
+ dropout_p,
159
+ softmax_scale=softmax_scale,
160
+ causal=reset_is_causal,
161
+ return_attn_probs=needs_weights,
162
+ )
163
+ output = bert_padding.pad_input(
164
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen
165
+ )
166
  return (output, None)
167
 
168
+
169
+ def triton_flash_attn_fn(
170
+ query,
171
+ key,
172
+ value,
173
+ n_heads,
174
+ softmax_scale=None,
175
+ attn_bias=None,
176
+ key_padding_mask=None,
177
+ is_causal=False,
178
+ dropout_p=0.0,
179
+ training=False,
180
+ needs_weights=False,
181
+ multiquery=False,
182
+ ):
183
  try:
184
  from flash_attn import flash_attn_triton
185
  except:
186
+ raise RuntimeError(
187
+ "Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202"
188
+ )
189
  check_valid_inputs(query, key, value)
190
  if dropout_p:
191
+ raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.")
192
  if needs_weights:
193
+ raise NotImplementedError(f"attn_impl: triton cannot return attn weights.")
194
  if key_padding_mask is not None:
195
+ warnings.warn(
196
+ "Propagating key_padding_mask to the attention module "
197
+ + "and applying it within the attention module can cause "
198
+ + "unnecessary computation/memory usage. Consider integrating "
199
+ + "into attn_bias once and passing that to each attention "
200
+ + "module instead."
201
+ )
202
  (b_size, s_k) = key_padding_mask.shape[:2]
203
  if attn_bias is None:
204
  attn_bias = query.new_zeros(b_size, 1, 1, s_k)
205
+ attn_bias = attn_bias.masked_fill(
206
+ ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min
207
+ )
208
+ query = rearrange(query, "b s (h d) -> b s h d", h=n_heads)
209
+ key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
210
+ value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
211
  if multiquery:
212
  key = key.expand(*key.shape[:2], n_heads, key.size(-1))
213
  value = value.expand(*value.shape[:2], n_heads, value.size(-1))
214
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
215
+ attn_output = flash_attn_triton.flash_attn_func(
216
+ query, key, value, attn_bias, reset_is_causal, softmax_scale
217
+ )
218
  output = attn_output.view(*attn_output.shape[:2], -1)
219
  return (output, None)
220
 
221
+
222
  class MultiheadAttention(nn.Module):
223
  """Multi-head self attention.
224
 
 
226
  additive bias.
227
  """
228
 
229
+ def __init__(
230
+ self,
231
+ d_model: int,
232
+ n_heads: int,
233
+ attn_impl: str = "triton",
234
+ clip_qkv: Optional[float] = None,
235
+ qk_ln: bool = False,
236
+ softmax_scale: Optional[float] = None,
237
+ attn_pdrop: float = 0.0,
238
+ low_precision_layernorm: bool = False,
239
+ device: Optional[str] = None,
240
+ ):
241
  super().__init__()
242
  self.attn_impl = attn_impl
243
  self.clip_qkv = clip_qkv
 
255
  layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
256
  self.q_ln = layernorm_class(self.d_model, device=device)
257
  self.k_ln = layernorm_class(self.d_model, device=device)
258
+ if self.attn_impl == "flash":
259
  self.attn_fn = flash_attn_fn
260
+ elif self.attn_impl == "triton":
261
  self.attn_fn = triton_flash_attn_fn
262
+ warnings.warn(
263
+ "While `attn_impl: triton` can be faster than `attn_impl: flash` "
264
+ + "it uses more memory. When training larger models this can trigger "
265
+ + "alloc retries which hurts performance. If encountered, we recommend "
266
+ + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
267
+ )
268
+ elif self.attn_impl == "torch":
269
  self.attn_fn = scaled_multihead_dot_product_attention
270
  if torch.cuda.is_available():
271
+ warnings.warn(
272
+ "Using `attn_impl: torch`. If your model does not use `alibi` or "
273
+ + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
274
+ + "we recommend using `attn_impl: triton`."
275
+ )
276
  else:
277
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
278
  self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
279
  self.out_proj._is_residual = True
280
 
281
+ def forward(
282
+ self,
283
+ x,
284
+ past_key_value=None,
285
+ attn_bias=None,
286
+ attention_mask=None,
287
+ is_causal=True,
288
+ needs_weights=False,
289
+ ):
290
  qkv = self.Wqkv(x)
291
  if self.clip_qkv:
292
  qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
 
302
  value = torch.cat([past_key_value[1], value], dim=1)
303
  past_key_value = (key, value)
304
  if attn_bias is not None:
305
+ attn_bias = attn_bias[:, :, -query.size(1) :, -key.size(1) :]
306
+ (context, attn_weights) = self.attn_fn(
307
+ query,
308
+ key,
309
+ value,
310
+ self.n_heads,
311
+ softmax_scale=self.softmax_scale,
312
+ attn_bias=attn_bias,
313
+ key_padding_mask=key_padding_mask,
314
+ is_causal=is_causal,
315
+ dropout_p=self.attn_dropout_p,
316
+ training=self.training,
317
+ needs_weights=needs_weights,
318
+ )
319
  return (self.out_proj(context), attn_weights, past_key_value)
320
 
321
+
322
  class MultiQueryAttention(nn.Module):
323
  """Multi-Query self attention.
324
 
 
326
  additive bias.
327
  """
328
 
329
+ def __init__(
330
+ self,
331
+ d_model: int,
332
+ n_heads: int,
333
+ attn_impl: str = "triton",
334
+ clip_qkv: Optional[float] = None,
335
+ qk_ln: bool = False,
336
+ softmax_scale: Optional[float] = None,
337
+ attn_pdrop: float = 0.0,
338
+ low_precision_layernorm: bool = False,
339
+ device: Optional[str] = None,
340
+ ):
341
  super().__init__()
342
  self.attn_impl = attn_impl
343
  self.clip_qkv = clip_qkv
 
356
  layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
357
  self.q_ln = layernorm_class(d_model, device=device)
358
  self.k_ln = layernorm_class(self.head_dim, device=device)
359
+ if self.attn_impl == "flash":
360
  self.attn_fn = flash_attn_fn
361
+ elif self.attn_impl == "triton":
362
  self.attn_fn = triton_flash_attn_fn
363
+ warnings.warn(
364
+ "While `attn_impl: triton` can be faster than `attn_impl: flash` "
365
+ + "it uses more memory. When training larger models this can trigger "
366
+ + "alloc retries which hurts performance. If encountered, we recommend "
367
+ + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
368
+ )
369
+ elif self.attn_impl == "torch":
370
  self.attn_fn = scaled_multihead_dot_product_attention
371
  if torch.cuda.is_available():
372
+ warnings.warn(
373
+ "Using `attn_impl: torch`. If your model does not use `alibi` or "
374
+ + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
375
+ + "we recommend using `attn_impl: triton`."
376
+ )
377
  else:
378
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
379
  self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
380
  self.out_proj._is_residual = True
381
 
382
+ def forward(
383
+ self,
384
+ x,
385
+ past_key_value=None,
386
+ attn_bias=None,
387
+ attention_mask=None,
388
+ is_causal=True,
389
+ needs_weights=False,
390
+ ):
391
  qkv = self.Wqkv(x)
392
  if self.clip_qkv:
393
  qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
394
+ (query, key, value) = qkv.split(
395
+ [self.d_model, self.head_dim, self.head_dim], dim=2
396
+ )
397
  key_padding_mask = attention_mask
398
  if self.qk_ln:
399
  dtype = query.dtype
 
405
  value = torch.cat([past_key_value[1], value], dim=1)
406
  past_key_value = (key, value)
407
  if attn_bias is not None:
408
+ attn_bias = attn_bias[:, :, -query.size(1) :, -key.size(1) :]
409
+ (context, attn_weights) = self.attn_fn(
410
+ query,
411
+ key,
412
+ value,
413
+ self.n_heads,
414
+ softmax_scale=self.softmax_scale,
415
+ attn_bias=attn_bias,
416
+ key_padding_mask=key_padding_mask,
417
+ is_causal=is_causal,
418
+ dropout_p=self.attn_dropout_p,
419
+ training=self.training,
420
+ needs_weights=needs_weights,
421
+ multiquery=True,
422
+ )
423
  return (self.out_proj(context), attn_weights, past_key_value)
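
The reformatted `MultiQueryAttention.forward` above also makes the multi-query layout easier to see: `Wqkv` emits `d_model` query features but only one `head_dim`-wide key and one value, which every query head shares via broadcasting. A small illustrative sketch of that split (toy dimensions, independent of the repo code):

```python
import torch
from einops import rearrange

d_model, n_heads = 32, 4
head_dim = d_model // n_heads
b, s = 2, 6

# Multi-query projection: full-width queries, one shared key/value head.
qkv = torch.randn(b, s, d_model + 2 * head_dim)
query, key, value = qkv.split([d_model, head_dim, head_dim], dim=2)

q = rearrange(query, "b s (h d) -> b h s d", h=n_heads)  # (2, 4, 6, 8)
k = rearrange(key, "b s (h d) -> b h d s", h=1)          # (2, 1, 8, 6)
v = rearrange(value, "b s (h d) -> b h s d", h=1)        # (2, 1, 6, 8)

# Broadcasting shares the single key/value head across all query heads.
scores = q.matmul(k) * head_dim ** -0.5                  # (2, 4, 6, 6)
out = rearrange(torch.softmax(scores, dim=-1).matmul(v), "b h s d -> b s (h d)")
print(out.shape)  # torch.Size([2, 6, 32])
```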
424
 
425
+
426
+ def attn_bias_shape(
427
+ attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id
428
+ ):
429
+ if attn_impl == "flash":
430
  return None
431
+ elif attn_impl in ["torch", "triton"]:
432
  if alibi:
433
  if (prefix_lm or not causal) or use_sequence_id:
434
  return (1, n_heads, seq_len, seq_len)
 
437
  return (1, 1, seq_len, seq_len)
438
  return None
439
  else:
440
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
441
 
442
+
443
+ def build_attn_bias(
444
+ attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8
445
+ ):
446
+ if attn_impl == "flash":
447
  return None
448
+ elif attn_impl in ["torch", "triton"]:
449
  if alibi:
450
  (device, dtype) = (attn_bias.device, attn_bias.dtype)
451
+ attn_bias = attn_bias.add(
452
+ build_alibi_bias(
453
+ n_heads,
454
+ seq_len,
455
+ full=not causal,
456
+ alibi_bias_max=alibi_bias_max,
457
+ device=device,
458
+ dtype=dtype,
459
+ )
460
+ )
461
  return attn_bias
462
  else:
463
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
464
+
465
 
466
  def gen_slopes(n_heads, alibi_bias_max=8, device=None):
467
  _n_heads = 2 ** math.ceil(math.log2(n_heads))
 
472
  slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
473
  return slopes.view(1, n_heads, 1, 1)
474
 
475
+
476
+ def build_alibi_bias(
477
+ n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None
478
+ ):
479
+ alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(
480
+ 1, 1, 1, seq_len
481
+ )
482
  if full:
483
+ alibi_bias = alibi_bias - torch.arange(
484
+ 1 - seq_len, 1, dtype=torch.int32, device=device
485
+ ).view(1, 1, seq_len, 1)
486
  alibi_bias = alibi_bias.abs().mul(-1)
487
  slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
488
  alibi_bias = alibi_bias * slopes
489
  return alibi_bias.to(dtype=dtype)
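
To make the ALiBi helpers above concrete: `gen_slopes` gives each head a slope of roughly `2**(-alibi_bias_max * i / n_heads)` (computed for the next power-of-two head count and re-interleaved), and `build_alibi_bias` multiplies those slopes by negative relative distances, so attention decays with distance instead of relying on position embeddings. A small standalone replica of that math (the diff elides part of `gen_slopes`; this sketch follows the standard ALiBi recipe and is illustrative only):

```python
import math
import torch

def gen_slopes(n_heads, alibi_bias_max=8):
    # Slopes for the next power-of-two head count, then interleave and truncate.
    _n_heads = 2 ** math.ceil(math.log2(n_heads))
    m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
    m = m.mul(alibi_bias_max / _n_heads)
    slopes = 1.0 / torch.pow(2, m)
    if _n_heads != n_heads:
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
    return slopes.view(1, n_heads, 1, 1)

def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8):
    # Negative absolute distance to the current token, scaled per head.
    dist = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
    return dist.abs().mul(-1) * gen_slopes(n_heads, alibi_bias_max)

bias = build_alibi_bias(n_heads=4, seq_len=5)
print(bias.shape)     # torch.Size([1, 4, 1, 5])
print(bias[0, 0, 0])  # head 0: largest penalty on the most distant positions
```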
490
+
491
+
492
+ ATTN_CLASS_REGISTRY = {
493
+ "multihead_attention": MultiheadAttention,
494
+ "multiquery_attention": MultiQueryAttention,
495
+ }
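
Putting the pieces at the bottom of this attention module together: callers first ask `attn_bias_shape` whether the chosen implementation needs a materialized bias, allocate zeros of that shape, and let `build_attn_bias` add the ALiBi term before handing the result to an attention class from `ATTN_CLASS_REGISTRY`. A usage sketch, assuming the repository layout in this commit (`model/llava/model/mpt/attention.py`) is importable and that the elided parts of `__init__` set up `Wqkv` and the default softmax scale as in upstream MPT:

```python
import torch

# Assumes the repo root is on PYTHONPATH; adjust the import path otherwise.
from model.llava.model.mpt.attention import (
    ATTN_CLASS_REGISTRY,
    attn_bias_shape,
    build_attn_bias,
)

n_heads, seq_len, d_model = 4, 16, 64
shape = attn_bias_shape(
    "torch", n_heads, seq_len, alibi=True, prefix_lm=False,
    causal=True, use_sequence_id=False,
)  # (1, n_heads, 1, seq_len) for causal ALiBi
attn_bias = build_attn_bias(
    "torch", torch.zeros(shape), n_heads, seq_len, causal=True, alibi=True
)

attn = ATTN_CLASS_REGISTRY["multihead_attention"](
    d_model=d_model, n_heads=n_heads, attn_impl="torch"
)
x = torch.randn(2, seq_len, d_model)
out, _, _ = attn(x, attn_bias=attn_bias, is_causal=True)
print(out.shape)  # torch.Size([2, 16, 64])
```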
model/llava/model/mpt/blocks.py CHANGED
@@ -1,41 +1,90 @@
- class MPTMLP(nn.Module):
- def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
- self.act = nn.GELU(approximate='none')
- class MPTBlock(nn.Module):
- def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
- attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
- self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device)
- self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
- def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
- (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
- return (x, past_key_value)
1
  """GPT Blocks used for the GPT Model."""
2
  from typing import Dict, Optional, Tuple
3
+
4
  import torch
5
  import torch.nn as nn
6
+
7
  from .attention import ATTN_CLASS_REGISTRY
8
  from .norm import NORM_CLASS_REGISTRY
9
 
 
10
 
11
+ class MPTMLP(nn.Module):
12
+ def __init__(
13
+ self, d_model: int, expansion_ratio: int, device: Optional[str] = None
14
+ ):
15
  super().__init__()
16
  self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
17
+ self.act = nn.GELU(approximate="none")
18
  self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
19
  self.down_proj._is_residual = True
20
 
21
  def forward(self, x):
22
  return self.down_proj(self.act(self.up_proj(x)))
23
 
 
24
 
25
+ class MPTBlock(nn.Module):
26
+ def __init__(
27
+ self,
28
+ d_model: int,
29
+ n_heads: int,
30
+ expansion_ratio: int,
31
+ attn_config: Dict = {
32
+ "attn_type": "multihead_attention",
33
+ "attn_pdrop": 0.0,
34
+ "attn_impl": "triton",
35
+ "qk_ln": False,
36
+ "clip_qkv": None,
37
+ "softmax_scale": None,
38
+ "prefix_lm": False,
39
+ "attn_uses_sequence_id": False,
40
+ "alibi": False,
41
+ "alibi_bias_max": 8,
42
+ },
43
+ resid_pdrop: float = 0.0,
44
+ norm_type: str = "low_precision_layernorm",
45
+ device: Optional[str] = None,
46
+ **kwargs
47
+ ):
48
  del kwargs
49
  super().__init__()
50
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
51
+ attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
52
  self.norm_1 = norm_class(d_model, device=device)
53
+ self.attn = attn_class(
54
+ attn_impl=attn_config["attn_impl"],
55
+ clip_qkv=attn_config["clip_qkv"],
56
+ qk_ln=attn_config["qk_ln"],
57
+ softmax_scale=attn_config["softmax_scale"],
58
+ attn_pdrop=attn_config["attn_pdrop"],
59
+ d_model=d_model,
60
+ n_heads=n_heads,
61
+ device=device,
62
+ )
63
  self.norm_2 = norm_class(d_model, device=device)
64
+ self.ffn = MPTMLP(
65
+ d_model=d_model, expansion_ratio=expansion_ratio, device=device
66
+ )
67
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
68
  self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
69
 
70
+ def forward(
71
+ self,
72
+ x: torch.Tensor,
73
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
74
+ attn_bias: Optional[torch.Tensor] = None,
75
+ attention_mask: Optional[torch.ByteTensor] = None,
76
+ is_causal: bool = True,
77
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
78
  a = self.norm_1(x)
79
+ (b, _, past_key_value) = self.attn(
80
+ a,
81
+ past_key_value=past_key_value,
82
+ attn_bias=attn_bias,
83
+ attention_mask=attention_mask,
84
+ is_causal=is_causal,
85
+ )
86
  x = x + self.resid_attn_dropout(b)
87
  m = self.norm_2(x)
88
  n = self.ffn(m)
89
  x = x + self.resid_ffn_dropout(n)
90
+ return (x, past_key_value)
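
The reformatted `MPTBlock.forward` above is a standard pre-norm residual block: normalize, attend, add, then normalize, feed-forward, add, with dropout on each residual branch. A compact self-contained sketch of the same wiring, using stock `nn.LayerNorm` and `nn.MultiheadAttention` as stand-ins for the MPT-specific modules (illustrative only):

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Pre-LN residual block mirroring the norm -> attn -> add, norm -> ffn -> add flow."""

    def __init__(self, d_model: int, n_heads: int, expansion_ratio: int = 4, p: float = 0.0):
        super().__init__()
        self.norm_1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm_2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, expansion_ratio * d_model),
            nn.GELU(approximate="none"),
            nn.Linear(expansion_ratio * d_model, d_model),
        )
        self.resid_attn_dropout = nn.Dropout(p)
        self.resid_ffn_dropout = nn.Dropout(p)

    def forward(self, x):
        a = self.norm_1(x)
        b, _ = self.attn(a, a, a, need_weights=False)
        x = x + self.resid_attn_dropout(b)
        n = self.ffn(self.norm_2(x))
        return x + self.resid_ffn_dropout(n)

block = PreNormBlock(d_model=32, n_heads=4)
print(block(torch.randn(2, 6, 32)).shape)  # torch.Size([2, 6, 32])
```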
model/llava/model/mpt/configuration_mpt.py CHANGED
@@ -1,13 +1,52 @@
- attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
- init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'}
- model_type = 'mpt'
- def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs):
@@ -80,39 +119,76 @@ class MPTConfig(PretrainedConfig):
- if 'name' in kwargs:
- del kwargs['name']
- if 'loss_fn' in kwargs:
- del kwargs['loss_fn']
- for (k, v) in config_defaults.items():
- self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
- self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
- raise ValueError('d_model must be divisible by n_heads')
- if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
- raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
- if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
- if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
- raise NotImplementedError('prefix_lm only implemented with torch and triton attention.')
- if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
- raise NotImplementedError('alibi only implemented with torch and triton attention.')
- if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
- raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.')
- raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!')
- if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model':
- raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
- if self.init_config.get('name', None) is None:
- raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
- if not self.learned_pos_emb and (not self.attn_config['alibi']):
- raise ValueError(f'Positional information must be provided to the model using either learned_pos_emb or alibi.')
1
  """A HuggingFace-style model configuration."""
2
  from typing import Dict, Optional, Union
3
+
4
  from transformers import PretrainedConfig
5
+
6
+ attn_config_defaults: Dict = {
7
+ "attn_type": "multihead_attention",
8
+ "attn_pdrop": 0.0,
9
+ "attn_impl": "triton",
10
+ "qk_ln": False,
11
+ "clip_qkv": None,
12
+ "softmax_scale": None,
13
+ "prefix_lm": False,
14
+ "attn_uses_sequence_id": False,
15
+ "alibi": False,
16
+ "alibi_bias_max": 8,
17
+ }
18
+ init_config_defaults: Dict = {
19
+ "name": "kaiming_normal_",
20
+ "fan_mode": "fan_in",
21
+ "init_nonlinearity": "relu",
22
+ }
23
+
24
 
25
  class MPTConfig(PretrainedConfig):
26
+ model_type = "mpt"
27
 
28
+ def __init__(
29
+ self,
30
+ d_model: int = 2048,
31
+ n_heads: int = 16,
32
+ n_layers: int = 24,
33
+ expansion_ratio: int = 4,
34
+ max_seq_len: int = 2048,
35
+ vocab_size: int = 50368,
36
+ resid_pdrop: float = 0.0,
37
+ emb_pdrop: float = 0.0,
38
+ learned_pos_emb: bool = True,
39
+ attn_config: Dict = attn_config_defaults,
40
+ init_device: str = "cpu",
41
+ logit_scale: Optional[Union[float, str]] = None,
42
+ no_bias: bool = False,
43
+ verbose: int = 0,
44
+ embedding_fraction: float = 1.0,
45
+ norm_type: str = "low_precision_layernorm",
46
+ use_cache: bool = False,
47
+ init_config: Dict = init_config_defaults,
48
+ **kwargs,
49
+ ):
50
  """The MPT configuration class.
51
 
52
  Args:
 
119
  self.norm_type = norm_type
120
  self.use_cache = use_cache
121
  self.init_config = init_config
122
+ if "name" in kwargs:
123
+ del kwargs["name"]
124
+ if "loss_fn" in kwargs:
125
+ del kwargs["loss_fn"]
126
  super().__init__(**kwargs)
127
  self._validate_config()
128
 
129
  def _set_config_defaults(self, config, config_defaults):
130
+ for k, v in config_defaults.items():
131
  if k not in config:
132
  config[k] = v
133
  return config
134
 
135
  def _validate_config(self):
136
+ self.attn_config = self._set_config_defaults(
137
+ self.attn_config, attn_config_defaults
138
+ )
139
+ self.init_config = self._set_config_defaults(
140
+ self.init_config, init_config_defaults
141
+ )
142
  if self.d_model % self.n_heads != 0:
143
+ raise ValueError("d_model must be divisible by n_heads")
144
+ if any(
145
+ (
146
+ prob < 0 or prob > 1
147
+ for prob in [
148
+ self.attn_config["attn_pdrop"],
149
+ self.resid_pdrop,
150
+ self.emb_pdrop,
151
+ ]
152
+ )
153
+ ):
154
+ raise ValueError(
155
+ "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1"
156
+ )
157
+ if self.attn_config["attn_impl"] not in ["torch", "flash", "triton"]:
158
  raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
159
+ if self.attn_config["prefix_lm"] and self.attn_config["attn_impl"] not in [
160
+ "torch",
161
+ "triton",
162
+ ]:
163
+ raise NotImplementedError(
164
+ "prefix_lm only implemented with torch and triton attention."
165
+ )
166
+ if self.attn_config["alibi"] and self.attn_config["attn_impl"] not in [
167
+ "torch",
168
+ "triton",
169
+ ]:
170
+ raise NotImplementedError(
171
+ "alibi only implemented with torch and triton attention."
172
+ )
173
+ if self.attn_config["attn_uses_sequence_id"] and self.attn_config[
174
+ "attn_impl"
175
+ ] not in ["torch", "triton"]:
176
+ raise NotImplementedError(
177
+ "attn_uses_sequence_id only implemented with torch and triton attention."
178
+ )
179
  if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
180
+ raise ValueError(
181
+ "model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!"
182
+ )
183
+ if isinstance(self.logit_scale, str) and self.logit_scale != "inv_sqrt_d_model":
184
+ raise ValueError(
185
+ f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
186
+ )
187
+ if self.init_config.get("name", None) is None:
188
+ raise ValueError(
189
+ f"self.init_config={self.init_config!r} 'name' needs to be set."
190
+ )
191
+ if not self.learned_pos_emb and (not self.attn_config["alibi"]):
192
+ raise ValueError(
193
+ f"Positional information must be provided to the model using either learned_pos_emb or alibi."
194
+ )
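
Most of the reformatted `configuration_mpt.py` is defaulting and validation: user-supplied `attn_config`/`init_config` dicts are back-filled from the module-level defaults, then a handful of invariants (head divisibility, dropout ranges, which features each `attn_impl` supports) are checked before a model is built. A stripped-down sketch of that pattern (names mirror the config file, but this is not the class itself):

```python
attn_config_defaults = {
    "attn_type": "multihead_attention",
    "attn_pdrop": 0.0,
    "attn_impl": "triton",
    "qk_ln": False,
    "clip_qkv": None,
    "softmax_scale": None,
    "prefix_lm": False,
    "attn_uses_sequence_id": False,
    "alibi": False,
    "alibi_bias_max": 8,
}

def set_config_defaults(config: dict, config_defaults: dict) -> dict:
    # Only fill in keys the user did not supply, like MPTConfig._set_config_defaults.
    for k, v in config_defaults.items():
        if k not in config:
            config[k] = v
    return config

def validate(d_model: int, n_heads: int, attn_config: dict) -> None:
    attn_config = set_config_defaults(attn_config, attn_config_defaults)
    if d_model % n_heads != 0:
        raise ValueError("d_model must be divisible by n_heads")
    if attn_config["attn_impl"] not in ["torch", "flash", "triton"]:
        raise ValueError(f"Unknown attn_impl={attn_config['attn_impl']}")
    if attn_config["alibi"] and attn_config["attn_impl"] not in ["torch", "triton"]:
        raise NotImplementedError("alibi only implemented with torch and triton attention.")

validate(d_model=2048, n_heads=16, attn_config={"attn_impl": "torch", "alibi": True})  # passes
```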
model/llava/model/mpt/hf_prefixlm_converter.py CHANGED
@@ -10,21 +10,37 @@ import math
10
  import warnings
11
  from types import MethodType
12
  from typing import Any, Dict, List, Optional, Tuple, Union
 
13
  import torch
14
- from transformers.models.bloom.modeling_bloom import BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, CausalLMOutputWithCrossAttentions, CrossEntropyLoss
15
- from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
16
- from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom
 
 
 
 
17
  from transformers.models.bloom.modeling_bloom import logging
18
  from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
19
  from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
20
  from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
21
  from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
22
  from transformers.models.opt.modeling_opt import OPTForCausalLM
23
- from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
24
- from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt
 
 
 
25
  logger = logging.get_logger(__name__)
26
- _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
27
- CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
 
 
 
 
 
 
 
 
28
 
29
  def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
30
  """Converts a GPT-style Causal LM to a Prefix LM.
@@ -37,10 +53,12 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
37
 
38
  See `convert_hf_causal_lm_to_prefix_lm` for more details.
39
  """
40
- if hasattr(model, '_prefix_lm_converted'):
41
  return model
42
  assert isinstance(model, _SUPPORTED_GPT_MODELS)
43
- assert model.config.add_cross_attention == False, 'Only supports GPT-style decoder-only models'
 
 
44
 
45
  def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
46
  """Helper that gets a list of the model's attention modules.
@@ -56,7 +74,7 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
56
  blocks = model.transformer.h
57
  for block in blocks:
58
  if isinstance(model, GPTNeoForCausalLM):
59
- if block.attn.attention_type != 'global':
60
  continue
61
  attn_module = block.attn.attention
62
  elif isinstance(model, GPTNeoXForCausalLM):
@@ -65,17 +83,58 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
65
  attn_module = block.attn
66
  attn_modules.append(attn_module)
67
  return attn_modules
68
- setattr(model, '_original_forward', getattr(model, 'forward'))
69
- setattr(model, '_original_generate', getattr(model, 'generate'))
70
 
71
- def forward(self: CAUSAL_GPT_TYPES, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]]=None, attention_mask: Optional[torch.FloatTensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, token_type_ids: Optional[torch.LongTensor]=None, position_ids: Optional[torch.LongTensor]=None, head_mask: Optional[torch.FloatTensor]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  """Wraps original forward to enable PrefixLM attention."""
73
 
74
  def call_og_forward():
75
  if isinstance(self, GPTNeoXForCausalLM):
76
- return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
 
 
 
 
 
 
 
 
 
 
 
77
  else:
78
- return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  if bidirectional_mask is None:
80
  return call_og_forward()
81
  assert isinstance(bidirectional_mask, torch.Tensor)
@@ -83,14 +142,23 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
83
  (b, s) = bidirectional_mask.shape
84
  max_length = attn_modules[0].bias.shape[-1]
85
  if s > max_length:
86
- raise ValueError(f'bidirectional_mask sequence length (={s}) exceeds the ' + f'max length allowed by the model ({max_length}).')
 
 
 
87
  assert s <= max_length
88
  if s < max_length:
89
- pad = torch.zeros((int(b), int(max_length - s)), dtype=bidirectional_mask.dtype, device=bidirectional_mask.device)
 
 
 
 
90
  bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
91
  bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
92
  for attn_module in attn_modules:
93
- attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
 
 
94
  output = call_og_forward()
95
  for attn_module in attn_modules:
96
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
@@ -105,11 +173,13 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
105
  for attn_module in attn_modules:
106
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
107
  return output
108
- setattr(model, 'forward', MethodType(forward, model))
109
- setattr(model, 'generate', MethodType(generate, model))
110
- setattr(model, '_prefix_lm_converted', True)
 
111
  return model
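
The GPT converter above turns a causal LM into a prefix LM by OR-ing a `bidirectional_mask` into each layer's causal `bias` buffer for the duration of a single forward pass, so prefix tokens attend to each other while the continuation stays causal. A small self-contained illustration of the resulting attention pattern (toy sizes; the converter itself patches the HF modules rather than building a mask like this):

```python
import torch

seq_len = 6

# Standard causal mask: row i may attend to columns <= i.
causal = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()

# bidirectional_mask marks the prefix positions (1 = prefix, 0 = continuation).
bidirectional_mask = torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.bool)

# Same combination as the converter: logical_or lets every position see the
# whole prefix, while the continuation remains strictly causal.
prefix_lm_mask = torch.logical_or(causal, bidirectional_mask.view(1, seq_len))
print(prefix_lm_mask.int())
```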
112
 
 
113
  def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
114
  """Converts a BLOOM Causal LM to a Prefix LM.
115
 
@@ -118,62 +188,137 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
118
 
119
  See `convert_hf_causal_lm_to_prefix_lm` for more details.
120
  """
121
- if hasattr(model, '_prefix_lm_converted'):
122
  return model
123
  assert isinstance(model, BloomForCausalLM)
124
- assert model.config.add_cross_attention == False, 'Only supports BLOOM decoder-only models'
125
-
126
- def _prepare_attn_mask(self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor:
 
 
 
 
 
 
 
 
127
  combined_attention_mask = None
128
  device = attention_mask.device
129
  (_, src_length) = input_shape
130
  if src_length > 1:
131
- combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length)
 
 
 
 
132
  if bidirectional_mask is not None:
133
  assert attention_mask.shape == bidirectional_mask.shape
134
- expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length)
135
- combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask)
 
 
 
 
136
  expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
137
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
 
 
 
 
138
  return combined_attention_mask
139
 
140
- def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
 
 
 
 
 
 
 
141
  num_heads = self.config.n_head
142
  closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
143
- base = torch.tensor(2 ** (-2 ** (-(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
144
- powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
 
 
 
 
 
 
145
  slopes = torch.pow(base, powers)
146
  if closest_power_of_2 != num_heads:
147
- extra_base = torch.tensor(2 ** (-2 ** (-(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32)
148
- num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
149
- extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
 
 
 
 
 
 
 
 
150
  slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
151
  qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)
152
  ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)
153
  diffs = qa - ka + key_length - query_length
154
  diffs = -diffs.abs()
155
- alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length)
156
- alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length)
 
 
 
 
157
  return alibi.to(dtype)
 
158
  KeyValueT = Tuple[torch.Tensor, torch.Tensor]
159
 
160
- def forward(self: BloomModel, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.LongTensor]=None, inputs_embeds: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
161
- if deprecated_arguments.pop('position_ids', False) is not False:
162
- warnings.warn('`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. ' + 'You can safely ignore passing `position_ids`.', FutureWarning)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  if len(deprecated_arguments) > 0:
164
- raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
165
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
166
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
 
 
 
167
  use_cache = use_cache if use_cache is not None else self.config.use_cache
168
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
169
  if input_ids is not None and inputs_embeds is not None:
170
- raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
 
 
171
  elif input_ids is not None:
172
  (batch_size, seq_length) = input_ids.shape
173
  elif inputs_embeds is not None:
174
  (batch_size, seq_length, _) = inputs_embeds.shape
175
  else:
176
- raise ValueError('You have to specify either input_ids or inputs_embeds')
177
  if past_key_values is None:
178
  past_key_values = tuple([None] * len(self.h))
179
  head_mask = self.get_head_mask(head_mask, self.config.n_layer)
@@ -190,28 +335,62 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
190
  past_key_values_length = tmp.shape[2]
191
  seq_length_with_past = seq_length_with_past + past_key_values_length
192
  if attention_mask is None:
193
- attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
 
 
194
  else:
195
  attention_mask = attention_mask.to(hidden_states.device)
196
- alibi = self._build_alibi_tensor(batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device)
197
- causal_mask = self._prepare_attn_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length)
198
- for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)):
 
 
 
 
 
 
 
 
 
 
 
199
  if output_hidden_states:
200
  hst = (hidden_states,)
201
  all_hidden_states = all_hidden_states + hst
202
  if self.gradient_checkpointing and self.training:
203
  if use_cache:
204
- logger.warning('`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...')
 
 
205
  use_cache = False
206
 
207
  def create_custom_forward(module):
208
-
209
  def custom_forward(*inputs):
210
- return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
 
 
 
 
 
211
  return custom_forward
212
- outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i])
 
 
 
 
 
 
 
213
  else:
214
- outputs = block(hidden_states, layer_past=layer_past, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi)
 
 
 
 
 
 
 
 
215
  hidden_states = outputs[0]
216
  if use_cache is True:
217
  presents = presents + (outputs[1],)
@@ -223,21 +402,77 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
223
  hst = (hidden_states,)
224
  all_hidden_states = all_hidden_states + hst
225
  if not return_dict:
226
- return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None))
227
- return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions)
228
- setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer))
229
- setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer))
230
- setattr(model.transformer, 'forward', MethodType(forward, model.transformer))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  KeyValueT = Tuple[torch.Tensor, torch.Tensor]
232
 
233
- def forward(self: BloomForCausalLM, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.Tensor]=None, inputs_embeds: Optional[torch.Tensor]=None, labels: Optional[torch.Tensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  """Replacement forward method for BloomCausalLM."""
235
- if deprecated_arguments.pop('position_ids', False) is not False:
236
- warnings.warn('`position_ids` have no functionality in BLOOM and will be removed ' + 'in v5.0.0. You can safely ignore passing `position_ids`.', FutureWarning)
 
 
 
 
237
  if len(deprecated_arguments) > 0:
238
- raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
239
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
240
- transformer_outputs = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, bidirectional_mask=bidirectional_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  hidden_states = transformer_outputs[0]
242
  lm_logits = self.lm_head(hidden_states)
243
  loss = None
@@ -246,13 +481,28 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
246
  shift_labels = labels[..., 1:].contiguous()
247
  (batch_size, seq_length, vocab_size) = shift_logits.shape
248
  loss_fct = CrossEntropyLoss()
249
- loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length))
 
 
 
250
  if not return_dict:
251
  output = (lm_logits,) + transformer_outputs[1:]
252
  return (loss,) + output if loss is not None else output
253
- return CausalLMOutputWithCrossAttentions(loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions)
254
-
255
- def prepare_inputs_for_generation(self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, **kwargs) -> dict:
 
 
 
 
 
 
 
 
 
 
 
 
256
  if past:
257
  input_ids = input_ids[:, -1].unsqueeze(-1)
258
  bidirectional_mask = None
@@ -260,12 +510,24 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
260
  past = self._convert_to_bloom_cache(past)
261
  else:
262
  bidirectional_mask = torch.ones_like(input_ids)
263
- return {'input_ids': input_ids, 'past_key_values': past, 'use_cache': True, 'attention_mask': attention_mask, 'bidirectional_mask': bidirectional_mask}
264
- setattr(model, 'forward', MethodType(forward, model))
265
- setattr(model, 'prepare_inputs_for_generation', MethodType(prepare_inputs_for_generation, model))
266
- setattr(model, '_prefix_lm_converted', True)
 
 
 
 
 
 
 
 
 
 
 
267
  return model
268
 
 
269
  def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
270
  """Converts an OPT Causal LM to a Prefix LM.
271
 
@@ -274,36 +536,89 @@ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM
274
 
275
  See `convert_hf_causal_lm_to_prefix_lm` for more details.
276
  """
277
- if hasattr(model, '_prefix_lm_converted'):
278
  return model
279
  assert isinstance(model, OPTForCausalLM)
280
- assert model.config.add_cross_attention == False, 'Only supports OPT decoder-only models'
281
- setattr(model, '_original_forward', getattr(model, 'forward'))
282
- setattr(model, '_original_generate', getattr(model, 'generate'))
 
 
283
  model.model.decoder.bidirectional_mask = None
284
 
285
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
 
 
286
  combined_attention_mask = None
287
  if input_shape[-1] > 1:
288
- if self.bidirectional_mask == 'g':
289
  (bsz, src_length) = input_shape
290
- combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
 
 
 
 
291
  else:
292
- combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device)
 
 
 
 
293
  if self.bidirectional_mask is not None:
294
  assert attention_mask.shape == self.bidirectional_mask.shape
295
- expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
296
- combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
 
 
 
 
 
 
297
  if attention_mask is not None:
298
- expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
299
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
 
 
 
 
 
 
300
  return combined_attention_mask
301
- setattr(model.model.decoder, '_prepare_decoder_attention_mask', MethodType(_prepare_decoder_attention_mask, model.model.decoder))
302
-
303
- def forward(self: OPTForCausalLM, input_ids: Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.ByteTensor]=None, head_mask: Optional[torch.Tensor]=None, past_key_values: Optional[List[torch.FloatTensor]]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None):
304
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
  def call_og_forward():
306
- return self._original_forward(input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
 
 
 
 
 
 
 
 
 
 
 
 
307
  if bidirectional_mask is None:
308
  return call_og_forward()
309
  self.model.decoder.bidirectional_mask = bidirectional_mask
@@ -317,7 +632,7 @@ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM
317
 
318
  def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]):
319
  """Wraps original generate to enable PrefixLM-style attention."""
320
- self.model.decoder.bidirectional_mask = 'g'
321
  try:
322
  output = self._original_generate(*args, **kwargs)
323
  except:
@@ -325,12 +640,23 @@ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM
325
  raise
326
  self.model.decoder.bidirectional_mask = None
327
  return output
328
- setattr(model, 'forward', MethodType(forward, model))
329
- setattr(model, 'generate', MethodType(generate, model))
330
- setattr(model, '_prefix_lm_converted', True)
 
331
  return model
 
 
332
  _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)
333
- CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM]
 
 
 
 
 
 
 
 
334
 
335
  def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
336
  """Converts a HuggingFace Causal LM to a Prefix LM.
@@ -396,7 +722,12 @@ def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES
396
  elif isinstance(model, OPTForCausalLM):
397
  return _convert_opt_causal_lm_to_prefix_lm(model)
398
  else:
399
- raise TypeError(f'Cannot convert model to Prefix LM. ' + f'Model does not belong to set of supported HF models:' + f'\n{_SUPPORTED_HF_MODELS}')
 
 
 
 
 
400
 
401
  def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
402
  """Attempts to add bidirectional_mask to batch if missing.
@@ -404,12 +735,16 @@ def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
404
  Raises:
405
  KeyError if bidirectional_mask is missing and can't be inferred
406
  """
407
- if 'bidirectional_mask' not in batch:
408
- if batch.get('mode', None) == 'icl_task':
409
- batch['bidirectional_mask'] = batch['attention_mask'].clone()
410
- for (i, continuation_indices) in enumerate(batch['continuation_indices']):
411
- batch['bidirectional_mask'][i, continuation_indices] = 0
412
- elif 'labels' in batch and 'attention_mask' in batch:
413
- batch['bidirectional_mask'] = torch.logical_and(torch.eq(batch['attention_mask'], 1), torch.eq(batch['labels'], -100)).type_as(batch['attention_mask'])
 
 
414
  else:
415
- raise KeyError('No bidirectional_mask in batch and not sure how to construct one.')
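
`add_bidirectional_mask_if_missing` infers the prefix from ordinary supervised fine-tuning batches: positions that are attended to (`attention_mask == 1`) but excluded from the loss (`labels == -100`) are treated as the bidirectional prefix. A minimal sketch of that rule on a toy batch (illustrative, not the actual collator):

```python
import torch

batch = {
    # Two padded sequences; -100 marks prompt tokens that are not scored.
    "attention_mask": torch.tensor([[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]),
    "labels": torch.tensor([[-100, -100, 7, 8, -100], [-100, 5, 6, 7, 8]]),
}

# Same rule as the helper above: attended-to but unscored tokens form the prefix.
batch["bidirectional_mask"] = torch.logical_and(
    torch.eq(batch["attention_mask"], 1), torch.eq(batch["labels"], -100)
).type_as(batch["attention_mask"])

print(batch["bidirectional_mask"])
# tensor([[1, 1, 0, 0, 0],
#         [1, 0, 0, 0, 0]])
```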
 
 
 
10
  import warnings
11
  from types import MethodType
12
  from typing import Any, Dict, List, Optional, Tuple, Union
13
+
14
  import torch
15
+ from transformers.models.bloom.modeling_bloom import (
16
+ BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel,
17
+ CausalLMOutputWithCrossAttentions, CrossEntropyLoss)
18
+ from transformers.models.bloom.modeling_bloom import \
19
+ _expand_mask as _expand_mask_bloom
20
+ from transformers.models.bloom.modeling_bloom import \
21
+ _make_causal_mask as _make_causal_mask_bloom
22
  from transformers.models.bloom.modeling_bloom import logging
23
  from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
24
  from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
25
  from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
26
  from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
27
  from transformers.models.opt.modeling_opt import OPTForCausalLM
28
+ from transformers.models.opt.modeling_opt import \
29
+ _expand_mask as _expand_mask_opt
30
+ from transformers.models.opt.modeling_opt import \
31
+ _make_causal_mask as _make_causal_mask_opt
32
+
33
  logger = logging.get_logger(__name__)
34
+ _SUPPORTED_GPT_MODELS = (
35
+ GPT2LMHeadModel,
36
+ GPTJForCausalLM,
37
+ GPTNeoForCausalLM,
38
+ GPTNeoXForCausalLM,
39
+ )
40
+ CAUSAL_GPT_TYPES = Union[
41
+ GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM
42
+ ]
43
+
44
 
45
  def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
46
  """Converts a GPT-style Causal LM to a Prefix LM.
 
53
 
54
  See `convert_hf_causal_lm_to_prefix_lm` for more details.
55
  """
56
+ if hasattr(model, "_prefix_lm_converted"):
57
  return model
58
  assert isinstance(model, _SUPPORTED_GPT_MODELS)
59
+ assert (
60
+ model.config.add_cross_attention == False
61
+ ), "Only supports GPT-style decoder-only models"
62
 
63
  def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
64
  """Helper that gets a list of the model's attention modules.
 
74
  blocks = model.transformer.h
75
  for block in blocks:
76
  if isinstance(model, GPTNeoForCausalLM):
77
+ if block.attn.attention_type != "global":
78
  continue
79
  attn_module = block.attn.attention
80
  elif isinstance(model, GPTNeoXForCausalLM):
 
83
  attn_module = block.attn
84
  attn_modules.append(attn_module)
85
  return attn_modules
 
 
86
 
87
+ setattr(model, "_original_forward", getattr(model, "forward"))
88
+ setattr(model, "_original_generate", getattr(model, "generate"))
89
+
90
+ def forward(
91
+ self: CAUSAL_GPT_TYPES,
92
+ input_ids: Optional[torch.LongTensor] = None,
93
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
94
+ attention_mask: Optional[torch.FloatTensor] = None,
95
+ bidirectional_mask: Optional[torch.Tensor] = None,
96
+ token_type_ids: Optional[torch.LongTensor] = None,
97
+ position_ids: Optional[torch.LongTensor] = None,
98
+ head_mask: Optional[torch.FloatTensor] = None,
99
+ inputs_embeds: Optional[torch.FloatTensor] = None,
100
+ labels: Optional[torch.LongTensor] = None,
101
+ use_cache: Optional[bool] = None,
102
+ output_attentions: Optional[bool] = None,
103
+ output_hidden_states: Optional[bool] = None,
104
+ return_dict: Optional[bool] = None,
105
+ ):
106
  """Wraps original forward to enable PrefixLM attention."""
107
 
108
  def call_og_forward():
109
  if isinstance(self, GPTNeoXForCausalLM):
110
+ return self._original_forward(
111
+ input_ids=input_ids,
112
+ past_key_values=past_key_values,
113
+ attention_mask=attention_mask,
114
+ head_mask=head_mask,
115
+ inputs_embeds=inputs_embeds,
116
+ labels=labels,
117
+ use_cache=use_cache,
118
+ output_attentions=output_attentions,
119
+ output_hidden_states=output_hidden_states,
120
+ return_dict=return_dict,
121
+ )
122
  else:
123
+ return self._original_forward(
124
+ input_ids=input_ids,
125
+ past_key_values=past_key_values,
126
+ attention_mask=attention_mask,
127
+ token_type_ids=token_type_ids,
128
+ position_ids=position_ids,
129
+ head_mask=head_mask,
130
+ inputs_embeds=inputs_embeds,
131
+ labels=labels,
132
+ use_cache=use_cache,
133
+ output_attentions=output_attentions,
134
+ output_hidden_states=output_hidden_states,
135
+ return_dict=return_dict,
136
+ )
137
+
138
  if bidirectional_mask is None:
139
  return call_og_forward()
140
  assert isinstance(bidirectional_mask, torch.Tensor)
 
142
  (b, s) = bidirectional_mask.shape
143
  max_length = attn_modules[0].bias.shape[-1]
144
  if s > max_length:
145
+ raise ValueError(
146
+ f"bidirectional_mask sequence length (={s}) exceeds the "
147
+ + f"max length allowed by the model ({max_length})."
148
+ )
149
  assert s <= max_length
150
  if s < max_length:
151
+ pad = torch.zeros(
152
+ (int(b), int(max_length - s)),
153
+ dtype=bidirectional_mask.dtype,
154
+ device=bidirectional_mask.device,
155
+ )
156
  bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
157
  bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
158
  for attn_module in attn_modules:
159
+ attn_module.bias.data = torch.logical_or(
160
+ attn_module.bias.data, bidirectional
161
+ )
162
  output = call_og_forward()
163
  for attn_module in attn_modules:
164
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
 
173
  for attn_module in attn_modules:
174
  attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
175
  return output
176
+
177
+ setattr(model, "forward", MethodType(forward, model))
178
+ setattr(model, "generate", MethodType(generate, model))
179
+ setattr(model, "_prefix_lm_converted", True)
180
  return model
181
 
182
+
183
  def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
184
  """Converts a BLOOM Causal LM to a Prefix LM.
185
 
 
188
 
189
  See `convert_hf_causal_lm_to_prefix_lm` for more details.
190
  """
191
+ if hasattr(model, "_prefix_lm_converted"):
192
  return model
193
  assert isinstance(model, BloomForCausalLM)
194
+ assert (
195
+ model.config.add_cross_attention == False
196
+ ), "Only supports BLOOM decoder-only models"
197
+
198
+ def _prepare_attn_mask(
199
+ self: BloomModel,
200
+ attention_mask: torch.Tensor,
201
+ bidirectional_mask: Optional[torch.Tensor],
202
+ input_shape: Tuple[int, int],
203
+ past_key_values_length: int,
204
+ ) -> torch.BoolTensor:
205
  combined_attention_mask = None
206
  device = attention_mask.device
207
  (_, src_length) = input_shape
208
  if src_length > 1:
209
+ combined_attention_mask = _make_causal_mask_bloom(
210
+ input_shape,
211
+ device=device,
212
+ past_key_values_length=past_key_values_length,
213
+ )
214
  if bidirectional_mask is not None:
215
  assert attention_mask.shape == bidirectional_mask.shape
216
+ expanded_bidirectional_mask = _expand_mask_bloom(
217
+ bidirectional_mask, tgt_length=src_length
218
+ )
219
+ combined_attention_mask = torch.logical_and(
220
+ combined_attention_mask, expanded_bidirectional_mask
221
+ )
222
  expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
223
+ combined_attention_mask = (
224
+ expanded_attn_mask
225
+ if combined_attention_mask is None
226
+ else expanded_attn_mask | combined_attention_mask
227
+ )
228
  return combined_attention_mask
229
 
230
+ def _build_alibi_tensor(
231
+ self: BloomModel,
232
+ batch_size: int,
233
+ query_length: int,
234
+ key_length: int,
235
+ dtype: torch.dtype,
236
+ device: torch.device,
237
+ ) -> torch.Tensor:
238
  num_heads = self.config.n_head
239
  closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
240
+ base = torch.tensor(
241
+ 2 ** (-(2 ** (-(math.log2(closest_power_of_2) - 3)))),
242
+ device=device,
243
+ dtype=torch.float32,
244
+ )
245
+ powers = torch.arange(
246
+ 1, 1 + closest_power_of_2, device=device, dtype=torch.int32
247
+ )
248
  slopes = torch.pow(base, powers)
249
  if closest_power_of_2 != num_heads:
250
+ extra_base = torch.tensor(
251
+ 2 ** (-(2 ** (-(math.log2(2 * closest_power_of_2) - 3)))),
252
+ device=device,
253
+ dtype=torch.float32,
254
+ )
255
+ num_remaining_heads = min(
256
+ closest_power_of_2, num_heads - closest_power_of_2
257
+ )
258
+ extra_powers = torch.arange(
259
+ 1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32
260
+ )
261
  slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
262
  qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)
263
  ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)
264
  diffs = qa - ka + key_length - query_length
265
  diffs = -diffs.abs()
266
+ alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(
267
+ 1, 1, query_length, key_length
268
+ )
269
+ alibi = alibi.expand(batch_size, -1, -1, -1).reshape(
270
+ -1, query_length, key_length
271
+ )
272
  return alibi.to(dtype)
273
+
274
  KeyValueT = Tuple[torch.Tensor, torch.Tensor]
275
 
276
+ def forward(
277
+ self: BloomModel,
278
+ input_ids: Optional[torch.LongTensor] = None,
279
+ past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
280
+ attention_mask: Optional[torch.Tensor] = None,
281
+ bidirectional_mask: Optional[torch.Tensor] = None,
282
+ head_mask: Optional[torch.LongTensor] = None,
283
+ inputs_embeds: Optional[torch.LongTensor] = None,
284
+ use_cache: Optional[bool] = None,
285
+ output_attentions: Optional[bool] = None,
286
+ output_hidden_states: Optional[bool] = None,
287
+ return_dict: Optional[bool] = None,
288
+ **deprecated_arguments,
289
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
290
+ if deprecated_arguments.pop("position_ids", False) is not False:
291
+ warnings.warn(
292
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. "
293
+ + "You can safely ignore passing `position_ids`.",
294
+ FutureWarning,
295
+ )
296
  if len(deprecated_arguments) > 0:
297
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
298
+ output_attentions = (
299
+ output_attentions
300
+ if output_attentions is not None
301
+ else self.config.output_attentions
302
+ )
303
+ output_hidden_states = (
304
+ output_hidden_states
305
+ if output_hidden_states is not None
306
+ else self.config.output_hidden_states
307
+ )
308
  use_cache = use_cache if use_cache is not None else self.config.use_cache
309
+ return_dict = (
310
+ return_dict if return_dict is not None else self.config.use_return_dict
311
+ )
312
  if input_ids is not None and inputs_embeds is not None:
313
+ raise ValueError(
314
+ "You cannot specify both input_ids and inputs_embeds at the same time"
315
+ )
316
  elif input_ids is not None:
317
  (batch_size, seq_length) = input_ids.shape
318
  elif inputs_embeds is not None:
319
  (batch_size, seq_length, _) = inputs_embeds.shape
320
  else:
321
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
322
  if past_key_values is None:
323
  past_key_values = tuple([None] * len(self.h))
324
  head_mask = self.get_head_mask(head_mask, self.config.n_layer)
 
335
  past_key_values_length = tmp.shape[2]
336
  seq_length_with_past = seq_length_with_past + past_key_values_length
337
  if attention_mask is None:
338
+ attention_mask = torch.ones(
339
+ (batch_size, seq_length_with_past), device=hidden_states.device
340
+ )
341
  else:
342
  attention_mask = attention_mask.to(hidden_states.device)
343
+ alibi = self._build_alibi_tensor(
344
+ batch_size=batch_size,
345
+ query_length=seq_length,
346
+ key_length=seq_length_with_past,
347
+ dtype=hidden_states.dtype,
348
+ device=hidden_states.device,
349
+ )
350
+ causal_mask = self._prepare_attn_mask(
351
+ attention_mask,
352
+ bidirectional_mask,
353
+ input_shape=(batch_size, seq_length),
354
+ past_key_values_length=past_key_values_length,
355
+ )
356
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
357
  if output_hidden_states:
358
  hst = (hidden_states,)
359
  all_hidden_states = all_hidden_states + hst
360
  if self.gradient_checkpointing and self.training:
361
  if use_cache:
362
+ logger.warning(
363
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
364
+ )
365
  use_cache = False
366
 
367
  def create_custom_forward(module):
 
368
  def custom_forward(*inputs):
369
+ return module(
370
+ *inputs,
371
+ use_cache=use_cache,
372
+ output_attentions=output_attentions,
373
+ )
374
+
375
  return custom_forward
376
+
377
+ outputs = torch.utils.checkpoint.checkpoint(
378
+ create_custom_forward(block),
379
+ hidden_states,
380
+ alibi,
381
+ causal_mask,
382
+ head_mask[i],
383
+ )
384
  else:
385
+ outputs = block(
386
+ hidden_states,
387
+ layer_past=layer_past,
388
+ attention_mask=causal_mask,
389
+ head_mask=head_mask[i],
390
+ use_cache=use_cache,
391
+ output_attentions=output_attentions,
392
+ alibi=alibi,
393
+ )
394
  hidden_states = outputs[0]
395
  if use_cache is True:
396
  presents = presents + (outputs[1],)
 
402
  hst = (hidden_states,)
403
  all_hidden_states = all_hidden_states + hst
404
  if not return_dict:
405
+ return tuple(
406
+ (
407
+ v
408
+ for v in [
409
+ hidden_states,
410
+ presents,
411
+ all_hidden_states,
412
+ all_self_attentions,
413
+ ]
414
+ if v is not None
415
+ )
416
+ )
417
+ return BaseModelOutputWithPastAndCrossAttentions(
418
+ last_hidden_state=hidden_states,
419
+ past_key_values=presents,
420
+ hidden_states=all_hidden_states,
421
+ attentions=all_self_attentions,
422
+ )
423
+
424
+ setattr(
425
+ model.transformer,
426
+ "_prepare_attn_mask",
427
+ MethodType(_prepare_attn_mask, model.transformer),
428
+ )
429
+ setattr(
430
+ model.transformer,
431
+ "_build_alibi_tensor",
432
+ MethodType(_build_alibi_tensor, model.transformer),
433
+ )
434
+ setattr(model.transformer, "forward", MethodType(forward, model.transformer))
435
  KeyValueT = Tuple[torch.Tensor, torch.Tensor]
436
 
437
+ def forward(
438
+ self: BloomForCausalLM,
439
+ input_ids: Optional[torch.LongTensor] = None,
440
+ past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
441
+ attention_mask: Optional[torch.Tensor] = None,
442
+ bidirectional_mask: Optional[torch.Tensor] = None,
443
+ head_mask: Optional[torch.Tensor] = None,
444
+ inputs_embeds: Optional[torch.Tensor] = None,
445
+ labels: Optional[torch.Tensor] = None,
446
+ use_cache: Optional[bool] = None,
447
+ output_attentions: Optional[bool] = None,
448
+ output_hidden_states: Optional[bool] = None,
449
+ return_dict: Optional[bool] = None,
450
+ **deprecated_arguments,
451
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
452
  """Replacement forward method for BloomCausalLM."""
453
+ if deprecated_arguments.pop("position_ids", False) is not False:
454
+ warnings.warn(
455
+ "`position_ids` have no functionality in BLOOM and will be removed "
456
+ + "in v5.0.0. You can safely ignore passing `position_ids`.",
457
+ FutureWarning,
458
+ )
459
  if len(deprecated_arguments) > 0:
460
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
461
+ return_dict = (
462
+ return_dict if return_dict is not None else self.config.use_return_dict
463
+ )
464
+ transformer_outputs = self.transformer(
465
+ input_ids,
466
+ past_key_values=past_key_values,
467
+ attention_mask=attention_mask,
468
+ bidirectional_mask=bidirectional_mask,
469
+ head_mask=head_mask,
470
+ inputs_embeds=inputs_embeds,
471
+ use_cache=use_cache,
472
+ output_attentions=output_attentions,
473
+ output_hidden_states=output_hidden_states,
474
+ return_dict=return_dict,
475
+ )
476
  hidden_states = transformer_outputs[0]
477
  lm_logits = self.lm_head(hidden_states)
478
  loss = None
 
481
  shift_labels = labels[..., 1:].contiguous()
482
  (batch_size, seq_length, vocab_size) = shift_logits.shape
483
  loss_fct = CrossEntropyLoss()
484
+ loss = loss_fct(
485
+ shift_logits.view(batch_size * seq_length, vocab_size),
486
+ shift_labels.view(batch_size * seq_length),
487
+ )
488
  if not return_dict:
489
  output = (lm_logits,) + transformer_outputs[1:]
490
  return (loss,) + output if loss is not None else output
491
+ return CausalLMOutputWithCrossAttentions(
492
+ loss=loss,
493
+ logits=lm_logits,
494
+ past_key_values=transformer_outputs.past_key_values,
495
+ hidden_states=transformer_outputs.hidden_states,
496
+ attentions=transformer_outputs.attentions,
497
+ )
498
+
499
+ def prepare_inputs_for_generation(
500
+ self: BloomForCausalLM,
501
+ input_ids: torch.LongTensor,
502
+ past: Optional[torch.Tensor] = None,
503
+ attention_mask: Optional[torch.Tensor] = None,
504
+ **kwargs,
505
+ ) -> dict:
506
  if past:
507
  input_ids = input_ids[:, -1].unsqueeze(-1)
508
  bidirectional_mask = None
 
510
  past = self._convert_to_bloom_cache(past)
511
  else:
512
  bidirectional_mask = torch.ones_like(input_ids)
513
+ return {
514
+ "input_ids": input_ids,
515
+ "past_key_values": past,
516
+ "use_cache": True,
517
+ "attention_mask": attention_mask,
518
+ "bidirectional_mask": bidirectional_mask,
519
+ }
520
+
521
+ setattr(model, "forward", MethodType(forward, model))
522
+ setattr(
523
+ model,
524
+ "prepare_inputs_for_generation",
525
+ MethodType(prepare_inputs_for_generation, model),
526
+ )
527
+ setattr(model, "_prefix_lm_converted", True)
528
  return model
529
 
530
+
531
  def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
532
  """Converts an OPT Causal LM to a Prefix LM.
533
 
 
536
 
537
  See `convert_hf_causal_lm_to_prefix_lm` for more details.
538
  """
539
+ if hasattr(model, "_prefix_lm_converted"):
540
  return model
541
  assert isinstance(model, OPTForCausalLM)
542
+ assert (
543
+ model.config.add_cross_attention == False
544
+ ), "Only supports OPT decoder-only models"
545
+ setattr(model, "_original_forward", getattr(model, "forward"))
546
+ setattr(model, "_original_generate", getattr(model, "generate"))
547
  model.model.decoder.bidirectional_mask = None
548
 
549
+ def _prepare_decoder_attention_mask(
550
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
551
+ ):
552
  combined_attention_mask = None
553
  if input_shape[-1] > 1:
554
+ if self.bidirectional_mask == "g":
555
  (bsz, src_length) = input_shape
556
+ combined_attention_mask = torch.zeros(
557
+ (bsz, 1, src_length, src_length + past_key_values_length),
558
+ dtype=inputs_embeds.dtype,
559
+ device=inputs_embeds.device,
560
+ )
561
  else:
562
+ combined_attention_mask = _make_causal_mask_opt(
563
+ input_shape,
564
+ inputs_embeds.dtype,
565
+ past_key_values_length=past_key_values_length,
566
+ ).to(inputs_embeds.device)
567
  if self.bidirectional_mask is not None:
568
  assert attention_mask.shape == self.bidirectional_mask.shape
569
+ expanded_bidirectional_mask = _expand_mask_opt(
570
+ self.bidirectional_mask,
571
+ inputs_embeds.dtype,
572
+ tgt_len=input_shape[-1],
573
+ ).to(inputs_embeds.device)
574
+ combined_attention_mask = torch.maximum(
575
+ expanded_bidirectional_mask, combined_attention_mask
576
+ )
577
  if attention_mask is not None:
578
+ expanded_attn_mask = _expand_mask_opt(
579
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
580
+ ).to(inputs_embeds.device)
581
+ combined_attention_mask = (
582
+ expanded_attn_mask
583
+ if combined_attention_mask is None
584
+ else expanded_attn_mask + combined_attention_mask
585
+ )
586
  return combined_attention_mask
 
 
 
587
 
588
+ setattr(
589
+ model.model.decoder,
590
+ "_prepare_decoder_attention_mask",
591
+ MethodType(_prepare_decoder_attention_mask, model.model.decoder),
592
+ )
593
+
594
+ def forward(
595
+ self: OPTForCausalLM,
596
+ input_ids: Optional[torch.LongTensor] = None,
597
+ attention_mask: Optional[torch.Tensor] = None,
598
+ bidirectional_mask: Optional[torch.ByteTensor] = None,
599
+ head_mask: Optional[torch.Tensor] = None,
600
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
601
+ inputs_embeds: Optional[torch.FloatTensor] = None,
602
+ labels: Optional[torch.LongTensor] = None,
603
+ use_cache: Optional[bool] = None,
604
+ output_attentions: Optional[bool] = None,
605
+ output_hidden_states: Optional[bool] = None,
606
+ return_dict: Optional[bool] = None,
607
+ ):
608
  def call_og_forward():
609
+ return self._original_forward(
610
+ input_ids=input_ids,
611
+ attention_mask=attention_mask,
612
+ head_mask=head_mask,
613
+ past_key_values=past_key_values,
614
+ inputs_embeds=inputs_embeds,
615
+ labels=labels,
616
+ use_cache=use_cache,
617
+ output_attentions=output_attentions,
618
+ output_hidden_states=output_hidden_states,
619
+ return_dict=return_dict,
620
+ )
621
+
622
  if bidirectional_mask is None:
623
  return call_og_forward()
624
  self.model.decoder.bidirectional_mask = bidirectional_mask
 
632
 
633
  def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]):
634
  """Wraps original generate to enable PrefixLM-style attention."""
635
+ self.model.decoder.bidirectional_mask = "g"
636
  try:
637
  output = self._original_generate(*args, **kwargs)
638
  except:
 
640
  raise
641
  self.model.decoder.bidirectional_mask = None
642
  return output
643
+
644
+ setattr(model, "forward", MethodType(forward, model))
645
+ setattr(model, "generate", MethodType(generate, model))
646
+ setattr(model, "_prefix_lm_converted", True)
647
  return model
648
+
649
+
650
  _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)
651
+ CAUSAL_LM_TYPES = Union[
652
+ GPT2LMHeadModel,
653
+ GPTJForCausalLM,
654
+ GPTNeoForCausalLM,
655
+ GPTNeoXForCausalLM,
656
+ BloomForCausalLM,
657
+ OPTForCausalLM,
658
+ ]
659
+
660
 
661
  def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
662
  """Converts a HuggingFace Causal LM to a Prefix LM.
 
722
  elif isinstance(model, OPTForCausalLM):
723
  return _convert_opt_causal_lm_to_prefix_lm(model)
724
  else:
725
+ raise TypeError(
726
+ f"Cannot convert model to Prefix LM. "
727
+ + f"Model does not belong to set of supported HF models:"
728
+ + f"\n{_SUPPORTED_HF_MODELS}"
729
+ )
730
+
731
 
732
  def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
733
  """Attempts to add bidirectional_mask to batch if missing.
 
735
  Raises:
736
  KeyError if bidirectional_mask is missing and can't be inferred
737
  """
738
+ if "bidirectional_mask" not in batch:
739
+ if batch.get("mode", None) == "icl_task":
740
+ batch["bidirectional_mask"] = batch["attention_mask"].clone()
741
+ for i, continuation_indices in enumerate(batch["continuation_indices"]):
742
+ batch["bidirectional_mask"][i, continuation_indices] = 0
743
+ elif "labels" in batch and "attention_mask" in batch:
744
+ batch["bidirectional_mask"] = torch.logical_and(
745
+ torch.eq(batch["attention_mask"], 1), torch.eq(batch["labels"], -100)
746
+ ).type_as(batch["attention_mask"])
747
  else:
748
+ raise KeyError(
749
+ "No bidirectional_mask in batch and not sure how to construct one."
750
+ )
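Taken together, the converter can be exercised on a stock GPT-2 checkpoint roughly as follows. This is a sketch, not part of the release; the import path assumes the repository root is on `PYTHONPATH`.

```
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel

from model.llava.model.mpt.hf_prefixlm_converter import \
    convert_hf_causal_lm_to_prefix_lm  # import path is an assumption

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = convert_hf_causal_lm_to_prefix_lm(GPT2LMHeadModel.from_pretrained("gpt2"))

batch = tokenizer("Translate to French: Hello", return_tensors="pt")
# Treat every attended token as prefix so the prompt is encoded bidirectionally.
batch["bidirectional_mask"] = batch["attention_mask"].clone()

with torch.no_grad():
    out = model(**batch)  # the patched forward accepts bidirectional_mask
print(out.logits.shape)   # (1, prompt_len, vocab_size)
```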
model/llava/model/mpt/meta_init_context.py CHANGED
@@ -1,9 +1,11 @@
1
  from contextlib import contextmanager
 
2
  import torch
3
  import torch.nn as nn
4
 
 
5
  @contextmanager
6
- def init_empty_weights(include_buffers: bool=False):
7
  """Meta initialization context manager.
8
 
9
  A context manager under which models are initialized with all parameters
@@ -30,11 +32,12 @@ def init_empty_weights(include_buffers: bool=False):
30
 
31
  </Tip>
32
  """
33
- with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f:
34
  yield f
35
 
 
36
  @contextmanager
37
- def init_on_device(device: torch.device, include_buffers: bool=False):
38
  """Device initialization context manager.
39
 
40
  A context manager under which models are initialized with all parameters
@@ -62,33 +65,47 @@ def init_on_device(device: torch.device, include_buffers: bool=False):
62
  if param is not None:
63
  param_cls = type(module._parameters[name])
64
  kwargs = module._parameters[name].__dict__
65
- module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
 
 
66
 
67
  def register_empty_buffer(module, name, buffer):
68
  old_register_buffer(module, name, buffer)
69
  if buffer is not None:
70
  module._buffers[name] = module._buffers[name].to(device)
 
71
  if include_buffers:
72
- tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']}
 
 
 
73
  else:
74
  tensor_constructors_to_patch = {}
75
 
76
  def patch_tensor_constructor(fn):
77
-
78
  def wrapper(*args, **kwargs):
79
- kwargs['device'] = device
80
  return fn(*args, **kwargs)
 
81
  return wrapper
 
82
  try:
83
  nn.Module.register_parameter = register_empty_parameter
84
  if include_buffers:
85
  nn.Module.register_buffer = register_empty_buffer
86
  for torch_function_name in tensor_constructors_to_patch.keys():
87
- setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
 
 
 
 
88
  yield
89
  finally:
90
  nn.Module.register_parameter = old_register_parameter
91
  if include_buffers:
92
  nn.Module.register_buffer = old_register_buffer
93
- for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items():
94
- setattr(torch, torch_function_name, old_torch_function)
 
 
 
 
1
  from contextlib import contextmanager
2
+
3
  import torch
4
  import torch.nn as nn
5
 
6
+
7
  @contextmanager
8
+ def init_empty_weights(include_buffers: bool = False):
9
  """Meta initialization context manager.
10
 
11
  A context manager under which models are initialized with all parameters
 
32
 
33
  </Tip>
34
  """
35
+ with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
36
  yield f
37
 
38
+
39
  @contextmanager
40
+ def init_on_device(device: torch.device, include_buffers: bool = False):
41
  """Device initialization context manager.
42
 
43
  A context manager under which models are initialized with all parameters
 
65
  if param is not None:
66
  param_cls = type(module._parameters[name])
67
  kwargs = module._parameters[name].__dict__
68
+ module._parameters[name] = param_cls(
69
+ module._parameters[name].to(device), **kwargs
70
+ )
71
 
72
  def register_empty_buffer(module, name, buffer):
73
  old_register_buffer(module, name, buffer)
74
  if buffer is not None:
75
  module._buffers[name] = module._buffers[name].to(device)
76
+
77
  if include_buffers:
78
+ tensor_constructors_to_patch = {
79
+ torch_function_name: getattr(torch, torch_function_name)
80
+ for torch_function_name in ["empty", "zeros", "ones", "full"]
81
+ }
82
  else:
83
  tensor_constructors_to_patch = {}
84
 
85
  def patch_tensor_constructor(fn):
 
86
  def wrapper(*args, **kwargs):
87
+ kwargs["device"] = device
88
  return fn(*args, **kwargs)
89
+
90
  return wrapper
91
+
92
  try:
93
  nn.Module.register_parameter = register_empty_parameter
94
  if include_buffers:
95
  nn.Module.register_buffer = register_empty_buffer
96
  for torch_function_name in tensor_constructors_to_patch.keys():
97
+ setattr(
98
+ torch,
99
+ torch_function_name,
100
+ patch_tensor_constructor(getattr(torch, torch_function_name)),
101
+ )
102
  yield
103
  finally:
104
  nn.Module.register_parameter = old_register_parameter
105
  if include_buffers:
106
  nn.Module.register_buffer = old_register_buffer
107
+ for (
108
+ torch_function_name,
109
+ old_torch_function,
110
+ ) in tensor_constructors_to_patch.items():
111
+ setattr(torch, torch_function_name, old_torch_function)
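A minimal usage sketch for `init_empty_weights` (same import-path assumption as above): modules built inside the context get meta-device parameters, so shapes are known but no storage is allocated, and real weights must be loaded before the module is used.

```
import torch.nn as nn

from model.llava.model.mpt.meta_init_context import \
    init_empty_weights  # import path is an assumption

# Parameters registered inside the context are moved to the meta device.
with init_empty_weights(include_buffers=False):
    layer = nn.Linear(4096, 4096)

print(layer.weight.device)  # meta -- materialize or load real weights before use
```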
model/llava/model/mpt/modeling_mpt.py CHANGED
@@ -5,68 +5,95 @@ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
5
  import math
6
  import warnings
7
  from typing import List, Optional, Tuple, Union
 
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
- from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
12
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 
 
 
 
13
  from .attention import attn_bias_shape, build_attn_bias
14
  from .blocks import MPTBlock
15
- from .norm import NORM_CLASS_REGISTRY
16
  from .configuration_mpt import MPTConfig
17
- from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
18
- from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
19
  from .meta_init_context import init_empty_weights
 
20
  from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
 
21
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
22
 
23
  from transformers.utils import logging
 
24
  logger = logging.get_logger(__name__)
25
 
 
26
  class MPTPreTrainedModel(PreTrainedModel):
27
  config_class = MPTConfig
28
- base_model_prefix = 'model'
29
 
30
- class MPTModel(MPTPreTrainedModel):
31
 
 
32
  def __init__(self, config: MPTConfig):
33
  config._validate_config()
34
  super().__init__(config)
35
- self.attn_impl = config.attn_config['attn_impl']
36
- self.prefix_lm = config.attn_config['prefix_lm']
37
- self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
38
- self.alibi = config.attn_config['alibi']
39
- self.alibi_bias_max = config.attn_config['alibi_bias_max']
40
  if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
41
- norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
42
- raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
 
 
43
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
44
  self.embedding_fraction = config.embedding_fraction
45
- self.wte = nn.Embedding(config.vocab_size, config.d_model, device=config.init_device)
 
 
46
  if not self.alibi:
47
- self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
 
 
48
  self.emb_drop = nn.Dropout(config.emb_pdrop)
49
- self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
 
 
 
 
 
50
  self.norm_f = norm_class(config.d_model, device=config.init_device)
51
- if config.init_device != 'meta':
52
  self.apply(self.param_init_fn)
53
  self.is_causal = not self.prefix_lm
54
  self._attn_bias_initialized = False
55
  self.attn_bias = None
56
- self.attn_bias_shape = attn_bias_shape(self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, prefix_lm=self.prefix_lm, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id)
 
 
 
 
 
 
 
 
57
  if config.no_bias:
58
  for module in self.modules():
59
- if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
60
  if config.verbose:
61
- warnings.warn(f'Removing bias ({module.bias}) from {module}.')
62
- module.register_parameter('bias', None)
63
  if config.verbose and config.verbose > 2:
64
  print(self)
65
- if 'verbose' not in self.config.init_config:
66
- self.config.init_config['verbose'] = self.config.verbose
67
- if self.config.init_config['verbose'] > 1:
68
- init_fn_name = self.config.init_config['name']
69
- warnings.warn(f'Using {init_fn_name} initialization.')
70
  self.gradient_checkpointing = False
71
 
72
  def get_input_embeddings(self):
@@ -76,13 +103,30 @@ class MPTModel(MPTPreTrainedModel):
76
  self.wte = value
77
 
78
  @torch.no_grad()
79
- def _attn_bias(self, device, dtype, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None):
 
 
 
 
 
 
 
80
  if not self._attn_bias_initialized:
81
  if self.attn_bias_shape:
82
- self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
83
- self.attn_bias = build_attn_bias(self.attn_impl, self.attn_bias, self.config.n_heads, self.config.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max)
 
 
 
 
 
 
 
 
 
 
84
  self._attn_bias_initialized = True
85
- if self.attn_impl == 'flash':
86
  return (self.attn_bias, attention_mask)
87
  if self.attn_bias is not None:
88
  self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
@@ -101,38 +145,71 @@ class MPTModel(MPTPreTrainedModel):
101
  else:
102
  attn_bias = attn_bias[:, :, :, -s_k:]
103
  if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
104
- raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
 
 
 
105
  min_val = torch.finfo(attn_bias.dtype).min
106
- attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
 
 
107
  return (attn_bias, None)
108
 
109
  def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
110
  (s_k, s_q) = attn_bias.shape[-2:]
111
  if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
112
- raise ValueError('attn_bias does not match the expected shape. ' + f'The last two dimensions should both be {self.config.max_length} ' + f'but are {s_k} and {s_q}.')
 
 
 
 
113
  seq_len = prefix_mask.shape[-1]
114
  if seq_len > self.config.max_seq_len:
115
- raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
 
 
116
  attn_bias = attn_bias[..., :seq_len, :seq_len]
117
- causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
 
 
118
  prefix = prefix_mask.view(-1, 1, 1, seq_len)
119
  cannot_attend = ~torch.logical_or(causal, prefix.bool())
120
  min_val = torch.finfo(attn_bias.dtype).min
121
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
122
  return attn_bias
123
 
124
- def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor):
 
 
125
  seq_len = sequence_id.shape[-1]
126
  if seq_len > self.config.max_seq_len:
127
- raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
 
 
128
  attn_bias = attn_bias[..., :seq_len, :seq_len]
129
- cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1)
 
 
130
  min_val = torch.finfo(attn_bias.dtype).min
131
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
132
  return attn_bias
133
 
134
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, tok_emb: Optional[torch.FloatTensor]=None):
135
- return_dict = return_dict if return_dict is not None else self.config.return_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  use_cache = use_cache if use_cache is not None else self.config.use_cache
137
 
138
  if self.gradient_checkpointing and self.training:
@@ -146,21 +223,41 @@ class MPTModel(MPTPreTrainedModel):
146
  if prefix_mask is not None:
147
  prefix_mask = prefix_mask.bool()
148
  if not return_dict:
149
- raise NotImplementedError('return_dict False is not implemented yet for MPT')
 
 
150
  if output_attentions:
151
- raise NotImplementedError('output_attentions is not implemented yet for MPT')
152
- if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
153
- raise NotImplementedError('MPT does not support training with left padding.')
 
 
 
 
 
 
 
 
154
  if self.prefix_lm and prefix_mask is None:
155
- raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
 
 
156
  if self.training:
157
  if self.attn_uses_sequence_id and sequence_id is None:
158
- raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
 
 
 
159
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
160
- warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
 
 
 
161
  if input_ids is not None:
162
  S = input_ids.size(1)
163
- assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
 
 
164
  tok_emb = self.wte(input_ids)
165
  else:
166
  assert tok_emb is not None
@@ -171,45 +268,85 @@ class MPTModel(MPTPreTrainedModel):
171
  past_position = 0
172
  if past_key_values is not None:
173
  if len(past_key_values) != self.config.n_layers:
174
- raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
 
 
 
175
  past_position = past_key_values[0][0].size(1)
176
  if S + past_position > self.config.max_seq_len:
177
- raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
178
- pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
 
 
 
 
 
 
 
179
  if attention_mask is not None:
180
- pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
 
 
 
 
 
 
181
  pos_emb = self.wpe(pos)
182
  x = tok_emb + pos_emb
183
  if self.embedding_fraction == 1:
184
  x = self.emb_drop(x)
185
  else:
186
- x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
 
 
187
  assert isinstance(self.emb_drop, nn.Module)
188
  x = self.emb_drop(x_shrunk)
189
- (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
 
 
 
 
 
 
190
  if use_cache and past_key_values is None:
191
  past_key_values = [() for _ in range(self.config.n_layers)]
192
  all_hidden_states = () if output_hidden_states else None
193
- for (b_idx, block) in enumerate(self.blocks):
194
  if output_hidden_states:
195
  assert all_hidden_states is not None
196
  all_hidden_states = all_hidden_states + (x,)
197
- past_key_value = past_key_values[b_idx] if past_key_values is not None else None
 
 
198
  if self.gradient_checkpointing and self.training:
199
  (x, past_key_value) = torch.utils.checkpoint.checkpoint(
200
- block,
201
- x, past_key_value, attn_bias, attention_mask, self.is_causal
202
  )
203
  else:
204
- (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
 
 
 
 
 
 
205
  if past_key_values is not None:
206
  past_key_values[b_idx] = past_key_value
207
  x = self.norm_f(x)
208
- return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)
 
 
 
 
209
 
210
  def param_init_fn(self, module):
211
- init_fn_name = self.config.init_config['name']
212
- MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
 
 
 
 
 
213
 
214
  def fsdp_wrap_fn(self, module):
215
  return isinstance(module, MPTBlock)
@@ -217,21 +354,23 @@ class MPTModel(MPTPreTrainedModel):
217
  def activation_checkpointing_fn(self, module):
218
  return isinstance(module, MPTBlock)
219
 
220
- class MPTForCausalLM(MPTPreTrainedModel):
221
 
 
222
  def __init__(self, config: MPTConfig):
223
  super().__init__(config)
224
  if not config.tie_word_embeddings:
225
- raise ValueError('MPTForCausalLM only supports tied word embeddings')
226
  self.transformer = MPTModel(config)
227
  self.logit_scale = None
228
  if config.logit_scale is not None:
229
  logit_scale = config.logit_scale
230
  if isinstance(logit_scale, str):
231
- if logit_scale == 'inv_sqrt_d_model':
232
  logit_scale = 1 / math.sqrt(config.d_model)
233
  else:
234
- raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
 
 
235
  self.logit_scale = logit_scale
236
 
237
  def get_input_embeddings(self):
@@ -252,25 +391,63 @@ class MPTForCausalLM(MPTPreTrainedModel):
252
  def get_decoder(self):
253
  return self.transformer
254
 
255
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
256
- return_dict = return_dict if return_dict is not None else self.config.return_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  use_cache = use_cache if use_cache is not None else self.config.use_cache
258
- outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
 
 
 
 
 
 
 
 
 
 
259
  logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
260
  if self.logit_scale is not None:
261
  if self.logit_scale == 0:
262
- warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
 
 
263
  logits *= self.logit_scale
264
  loss = None
265
  if labels is not None:
266
  labels = torch.roll(labels, shifts=-1)
267
  labels[:, -1] = -100
268
- loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
269
- return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
 
 
 
 
 
 
 
270
 
271
  def param_init_fn(self, module):
272
- init_fn_name = self.config.init_config['name']
273
- MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
 
 
 
 
 
274
 
275
  def fsdp_wrap_fn(self, module):
276
  return isinstance(module, MPTBlock)
@@ -278,12 +455,16 @@ class MPTForCausalLM(MPTPreTrainedModel):
278
  def activation_checkpointing_fn(self, module):
279
  return isinstance(module, MPTBlock)
280
 
281
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
 
 
282
  if inputs_embeds is not None:
283
- raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
284
- attention_mask = kwargs['attention_mask'].bool()
285
  if attention_mask[:, -1].sum() != attention_mask.shape[0]:
286
- raise NotImplementedError('MPT does not support generation with right padding.')
 
 
287
  if self.transformer.attn_uses_sequence_id and self.training:
288
  sequence_id = torch.zeros_like(input_ids[:1])
289
  else:
@@ -292,11 +473,20 @@ class MPTForCausalLM(MPTPreTrainedModel):
292
  input_ids = input_ids[:, -1].unsqueeze(-1)
293
  if self.transformer.prefix_lm:
294
  prefix_mask = torch.ones_like(attention_mask)
295
- if kwargs.get('use_cache') == False:
296
- raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
 
 
297
  else:
298
  prefix_mask = None
299
- return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True)}
 
 
 
 
 
 
 
300
 
301
  @staticmethod
302
  def _reorder_cache(past_key_values, beam_idx):
@@ -307,5 +497,9 @@ class MPTForCausalLM(MPTPreTrainedModel):
307
  """
308
  reordered_past = []
309
  for layer_past in past_key_values:
310
- reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
311
- return reordered_past
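`_apply_prefix_mask` in `MPTModel` combines a causal mask with the prefix mask so that prefix tokens attend to each other bidirectionally while the remaining tokens stay causal. A toy illustration of that combination (values chosen only for illustration):

```
import torch

seq_len = 4
prefix_mask = torch.tensor([[1, 1, 0, 0]])                           # 2-token prefix
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # standard causal mask
prefix = prefix_mask.view(-1, 1, 1, seq_len).bool()

# Prefix key positions become visible to every query; the rest stays causal.
can_attend = torch.logical_or(causal, prefix)
print(can_attend.int()[0, 0])
# tensor([[1, 1, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])
```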
 
 
 
 
 
5
  import math
6
  import warnings
7
  from typing import List, Optional, Tuple, Union
8
+
9
  import torch
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
+ from transformers import (PreTrainedModel, PreTrainedTokenizer,
13
+ PreTrainedTokenizerFast)
14
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
15
+ CausalLMOutputWithPast)
16
+
17
+ from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
18
  from .attention import attn_bias_shape, build_attn_bias
19
  from .blocks import MPTBlock
 
20
  from .configuration_mpt import MPTConfig
21
+ from .hf_prefixlm_converter import (add_bidirectional_mask_if_missing,
22
+ convert_hf_causal_lm_to_prefix_lm)
23
  from .meta_init_context import init_empty_weights
24
+ from .norm import NORM_CLASS_REGISTRY
25
  from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
26
+
27
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
28
 
29
  from transformers.utils import logging
30
+
31
  logger = logging.get_logger(__name__)
32
 
33
+
34
  class MPTPreTrainedModel(PreTrainedModel):
35
  config_class = MPTConfig
36
+ base_model_prefix = "model"
37
 
 
38
 
39
+ class MPTModel(MPTPreTrainedModel):
40
  def __init__(self, config: MPTConfig):
41
  config._validate_config()
42
  super().__init__(config)
43
+ self.attn_impl = config.attn_config["attn_impl"]
44
+ self.prefix_lm = config.attn_config["prefix_lm"]
45
+ self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
46
+ self.alibi = config.attn_config["alibi"]
47
+ self.alibi_bias_max = config.attn_config["alibi_bias_max"]
48
  if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
49
+ norm_options = " | ".join(NORM_CLASS_REGISTRY.keys())
50
+ raise NotImplementedError(
51
+ f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})."
52
+ )
53
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
54
  self.embedding_fraction = config.embedding_fraction
55
+ self.wte = nn.Embedding(
56
+ config.vocab_size, config.d_model, device=config.init_device
57
+ )
58
  if not self.alibi:
59
+ self.wpe = nn.Embedding(
60
+ config.max_seq_len, config.d_model, device=config.init_device
61
+ )
62
  self.emb_drop = nn.Dropout(config.emb_pdrop)
63
+ self.blocks = nn.ModuleList(
64
+ [
65
+ MPTBlock(device=config.init_device, **config.to_dict())
66
+ for _ in range(config.n_layers)
67
+ ]
68
+ )
69
  self.norm_f = norm_class(config.d_model, device=config.init_device)
70
+ if config.init_device != "meta":
71
  self.apply(self.param_init_fn)
72
  self.is_causal = not self.prefix_lm
73
  self._attn_bias_initialized = False
74
  self.attn_bias = None
75
+ self.attn_bias_shape = attn_bias_shape(
76
+ self.attn_impl,
77
+ config.n_heads,
78
+ config.max_seq_len,
79
+ self.alibi,
80
+ prefix_lm=self.prefix_lm,
81
+ causal=self.is_causal,
82
+ use_sequence_id=self.attn_uses_sequence_id,
83
+ )
84
  if config.no_bias:
85
  for module in self.modules():
86
+ if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
87
  if config.verbose:
88
+ warnings.warn(f"Removing bias ({module.bias}) from {module}.")
89
+ module.register_parameter("bias", None)
90
  if config.verbose and config.verbose > 2:
91
  print(self)
92
+ if "verbose" not in self.config.init_config:
93
+ self.config.init_config["verbose"] = self.config.verbose
94
+ if self.config.init_config["verbose"] > 1:
95
+ init_fn_name = self.config.init_config["name"]
96
+ warnings.warn(f"Using {init_fn_name} initialization.")
97
  self.gradient_checkpointing = False
98
 
99
  def get_input_embeddings(self):
 
103
  self.wte = value
104
 
105
  @torch.no_grad()
106
+ def _attn_bias(
107
+ self,
108
+ device,
109
+ dtype,
110
+ attention_mask: Optional[torch.ByteTensor] = None,
111
+ prefix_mask: Optional[torch.ByteTensor] = None,
112
+ sequence_id: Optional[torch.LongTensor] = None,
113
+ ):
114
  if not self._attn_bias_initialized:
115
  if self.attn_bias_shape:
116
+ self.attn_bias = torch.zeros(
117
+ self.attn_bias_shape, device=device, dtype=dtype
118
+ )
119
+ self.attn_bias = build_attn_bias(
120
+ self.attn_impl,
121
+ self.attn_bias,
122
+ self.config.n_heads,
123
+ self.config.max_seq_len,
124
+ causal=self.is_causal,
125
+ alibi=self.alibi,
126
+ alibi_bias_max=self.alibi_bias_max,
127
+ )
128
  self._attn_bias_initialized = True
129
+ if self.attn_impl == "flash":
130
  return (self.attn_bias, attention_mask)
131
  if self.attn_bias is not None:
132
  self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
 
145
  else:
146
  attn_bias = attn_bias[:, :, :, -s_k:]
147
  if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
148
+ raise ValueError(
149
+ f"attention_mask shape={attention_mask.shape} "
150
+ + f"and prefix_mask shape={prefix_mask.shape} are not equal."
151
+ )
152
  min_val = torch.finfo(attn_bias.dtype).min
153
+ attn_bias = attn_bias.masked_fill(
154
+ ~attention_mask.view(-1, 1, 1, s_k), min_val
155
+ )
156
  return (attn_bias, None)
157
 
158
  def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
159
  (s_k, s_q) = attn_bias.shape[-2:]
160
  if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
161
+ raise ValueError(
162
+ "attn_bias does not match the expected shape. "
163
+ + f"The last two dimensions should both be {self.config.max_length} "
164
+ + f"but are {s_k} and {s_q}."
165
+ )
166
  seq_len = prefix_mask.shape[-1]
167
  if seq_len > self.config.max_seq_len:
168
+ raise ValueError(
169
+ f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
170
+ )
171
  attn_bias = attn_bias[..., :seq_len, :seq_len]
172
+ causal = torch.tril(
173
+ torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)
174
+ ).view(1, 1, seq_len, seq_len)
175
  prefix = prefix_mask.view(-1, 1, 1, seq_len)
176
  cannot_attend = ~torch.logical_or(causal, prefix.bool())
177
  min_val = torch.finfo(attn_bias.dtype).min
178
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
179
  return attn_bias
180
 
181
+ def _apply_sequence_id(
182
+ self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor
183
+ ):
184
  seq_len = sequence_id.shape[-1]
185
  if seq_len > self.config.max_seq_len:
186
+ raise ValueError(
187
+ f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
188
+ )
189
  attn_bias = attn_bias[..., :seq_len, :seq_len]
190
+ cannot_attend = torch.logical_not(
191
+ torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))
192
+ ).unsqueeze(1)
193
  min_val = torch.finfo(attn_bias.dtype).min
194
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
195
  return attn_bias
196
 
197
+ def forward(
198
+ self,
199
+ input_ids: torch.LongTensor,
200
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
201
+ attention_mask: Optional[torch.ByteTensor] = None,
202
+ prefix_mask: Optional[torch.ByteTensor] = None,
203
+ sequence_id: Optional[torch.LongTensor] = None,
204
+ return_dict: Optional[bool] = None,
205
+ output_attentions: Optional[bool] = None,
206
+ output_hidden_states: Optional[bool] = None,
207
+ use_cache: Optional[bool] = None,
208
+ tok_emb: Optional[torch.FloatTensor] = None,
209
+ ):
210
+ return_dict = (
211
+ return_dict if return_dict is not None else self.config.return_dict
212
+ )
213
  use_cache = use_cache if use_cache is not None else self.config.use_cache
214
 
215
  if self.gradient_checkpointing and self.training:
 
223
  if prefix_mask is not None:
224
  prefix_mask = prefix_mask.bool()
225
  if not return_dict:
226
+ raise NotImplementedError(
227
+ "return_dict False is not implemented yet for MPT"
228
+ )
229
  if output_attentions:
230
+ raise NotImplementedError(
231
+ "output_attentions is not implemented yet for MPT"
232
+ )
233
+ if (
234
+ attention_mask is not None
235
+ and attention_mask[:, 0].sum() != attention_mask.shape[0]
236
+ and self.training
237
+ ):
238
+ raise NotImplementedError(
239
+ "MPT does not support training with left padding."
240
+ )
241
  if self.prefix_lm and prefix_mask is None:
242
+ raise ValueError(
243
+ "prefix_mask is a required argument when MPT is configured with prefix_lm=True."
244
+ )
245
  if self.training:
246
  if self.attn_uses_sequence_id and sequence_id is None:
247
+ raise ValueError(
248
+ "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True "
249
+ + "and the model is in train mode."
250
+ )
251
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
252
+ warnings.warn(
253
+ "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
254
+ + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
255
+ )
256
  if input_ids is not None:
257
  S = input_ids.size(1)
258
+ assert (
259
+ S <= self.config.max_seq_len
260
+ ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}"
261
  tok_emb = self.wte(input_ids)
262
  else:
263
  assert tok_emb is not None
 
268
  past_position = 0
269
  if past_key_values is not None:
270
  if len(past_key_values) != self.config.n_layers:
271
+ raise ValueError(
272
+ f"past_key_values must provide a past_key_value for each attention "
273
+ + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})."
274
+ )
275
  past_position = past_key_values[0][0].size(1)
276
  if S + past_position > self.config.max_seq_len:
277
+ raise ValueError(
278
+ f"Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}."
279
+ )
280
+ pos = torch.arange(
281
+ past_position,
282
+ S + past_position,
283
+ dtype=torch.long,
284
+ device=input_ids.device,
285
+ ).unsqueeze(0)
286
  if attention_mask is not None:
287
+ pos = torch.clamp(
288
+ pos
289
+ - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[
290
+ :, past_position:
291
+ ],
292
+ min=0,
293
+ )
294
  pos_emb = self.wpe(pos)
295
  x = tok_emb + pos_emb
296
  if self.embedding_fraction == 1:
297
  x = self.emb_drop(x)
298
  else:
299
+ x_shrunk = x * self.embedding_fraction + x.detach() * (
300
+ 1 - self.embedding_fraction
301
+ )
302
  assert isinstance(self.emb_drop, nn.Module)
303
  x = self.emb_drop(x_shrunk)
304
+ (attn_bias, attention_mask) = self._attn_bias(
305
+ device=x.device,
306
+ dtype=x.dtype,
307
+ attention_mask=attention_mask,
308
+ prefix_mask=prefix_mask,
309
+ sequence_id=sequence_id,
310
+ )
311
  if use_cache and past_key_values is None:
312
  past_key_values = [() for _ in range(self.config.n_layers)]
313
  all_hidden_states = () if output_hidden_states else None
314
+ for b_idx, block in enumerate(self.blocks):
315
  if output_hidden_states:
316
  assert all_hidden_states is not None
317
  all_hidden_states = all_hidden_states + (x,)
318
+ past_key_value = (
319
+ past_key_values[b_idx] if past_key_values is not None else None
320
+ )
321
  if self.gradient_checkpointing and self.training:
322
  (x, past_key_value) = torch.utils.checkpoint.checkpoint(
323
+ block, x, past_key_value, attn_bias, attention_mask, self.is_causal
 
324
  )
325
  else:
326
+ (x, past_key_value) = block(
327
+ x,
328
+ past_key_value=past_key_value,
329
+ attn_bias=attn_bias,
330
+ attention_mask=attention_mask,
331
+ is_causal=self.is_causal,
332
+ )
333
  if past_key_values is not None:
334
  past_key_values[b_idx] = past_key_value
335
  x = self.norm_f(x)
336
+ return BaseModelOutputWithPast(
337
+ last_hidden_state=x,
338
+ past_key_values=past_key_values,
339
+ hidden_states=all_hidden_states,
340
+ )
341
 
342
  def param_init_fn(self, module):
343
+ init_fn_name = self.config.init_config["name"]
344
+ MODEL_INIT_REGISTRY[init_fn_name](
345
+ module=module,
346
+ n_layers=self.config.n_layers,
347
+ d_model=self.config.d_model,
348
+ **self.config.init_config,
349
+ )
350
 
351
  def fsdp_wrap_fn(self, module):
352
  return isinstance(module, MPTBlock)
 
354
  def activation_checkpointing_fn(self, module):
355
  return isinstance(module, MPTBlock)
356
 
 
357
 
358
+ class MPTForCausalLM(MPTPreTrainedModel):
359
  def __init__(self, config: MPTConfig):
360
  super().__init__(config)
361
  if not config.tie_word_embeddings:
362
+ raise ValueError("MPTForCausalLM only supports tied word embeddings")
363
  self.transformer = MPTModel(config)
364
  self.logit_scale = None
365
  if config.logit_scale is not None:
366
  logit_scale = config.logit_scale
367
  if isinstance(logit_scale, str):
368
+ if logit_scale == "inv_sqrt_d_model":
369
  logit_scale = 1 / math.sqrt(config.d_model)
370
  else:
371
+ raise ValueError(
372
+ f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
373
+ )
374
  self.logit_scale = logit_scale
375
 
376
  def get_input_embeddings(self):
 
391
  def get_decoder(self):
392
  return self.transformer
393
 
394
+ def forward(
395
+ self,
396
+ input_ids: torch.LongTensor,
397
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
398
+ attention_mask: Optional[torch.ByteTensor] = None,
399
+ prefix_mask: Optional[torch.ByteTensor] = None,
400
+ sequence_id: Optional[torch.LongTensor] = None,
401
+ labels: Optional[torch.LongTensor] = None,
402
+ return_dict: Optional[bool] = None,
403
+ output_attentions: Optional[bool] = None,
404
+ output_hidden_states: Optional[bool] = None,
405
+ use_cache: Optional[bool] = None,
406
+ ):
407
+ return_dict = (
408
+ return_dict if return_dict is not None else self.config.return_dict
409
+ )
410
  use_cache = use_cache if use_cache is not None else self.config.use_cache
411
+ outputs = self.transformer(
412
+ input_ids=input_ids,
413
+ past_key_values=past_key_values,
414
+ attention_mask=attention_mask,
415
+ prefix_mask=prefix_mask,
416
+ sequence_id=sequence_id,
417
+ return_dict=return_dict,
418
+ output_attentions=output_attentions,
419
+ output_hidden_states=output_hidden_states,
420
+ use_cache=use_cache,
421
+ )
422
  logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
423
  if self.logit_scale is not None:
424
  if self.logit_scale == 0:
425
+ warnings.warn(
426
+ f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs."
427
+ )
428
  logits *= self.logit_scale
429
  loss = None
430
  if labels is not None:
431
  labels = torch.roll(labels, shifts=-1)
432
  labels[:, -1] = -100
433
+ loss = F.cross_entropy(
434
+ logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
435
+ )
436
+ return CausalLMOutputWithPast(
437
+ loss=loss,
438
+ logits=logits,
439
+ past_key_values=outputs.past_key_values,
440
+ hidden_states=outputs.hidden_states,
441
+ )
442
 
443
  def param_init_fn(self, module):
444
+ init_fn_name = self.config.init_config["name"]
445
+ MODEL_INIT_REGISTRY[init_fn_name](
446
+ module=module,
447
+ n_layers=self.config.n_layers,
448
+ d_model=self.config.d_model,
449
+ **self.config.init_config,
450
+ )
451
 
452
  def fsdp_wrap_fn(self, module):
453
  return isinstance(module, MPTBlock)
 
455
  def activation_checkpointing_fn(self, module):
456
  return isinstance(module, MPTBlock)
457
 
458
+ def prepare_inputs_for_generation(
459
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
460
+ ):
461
  if inputs_embeds is not None:
462
+ raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
463
+ attention_mask = kwargs["attention_mask"].bool()
464
  if attention_mask[:, -1].sum() != attention_mask.shape[0]:
465
+ raise NotImplementedError(
466
+ "MPT does not support generation with right padding."
467
+ )
468
  if self.transformer.attn_uses_sequence_id and self.training:
469
  sequence_id = torch.zeros_like(input_ids[:1])
470
  else:
 
473
  input_ids = input_ids[:, -1].unsqueeze(-1)
474
  if self.transformer.prefix_lm:
475
  prefix_mask = torch.ones_like(attention_mask)
476
+ if kwargs.get("use_cache") == False:
477
+ raise NotImplementedError(
478
+ "MPT with prefix_lm=True does not support use_cache=False."
479
+ )
480
  else:
481
  prefix_mask = None
482
+ return {
483
+ "input_ids": input_ids,
484
+ "attention_mask": attention_mask,
485
+ "prefix_mask": prefix_mask,
486
+ "sequence_id": sequence_id,
487
+ "past_key_values": past_key_values,
488
+ "use_cache": kwargs.get("use_cache", True),
489
+ }
490
 
491
  @staticmethod
492
  def _reorder_cache(past_key_values, beam_idx):
 
497
  """
498
  reordered_past = []
499
  for layer_past in past_key_values:
500
+ reordered_past += [
501
+ tuple(
502
+ (past_state.index_select(0, beam_idx) for past_state in layer_past)
503
+ )
504
+ ]
505
+ return reordered_past
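For reference, the next-token loss computed in `MPTForCausalLM.forward` above can be reproduced in isolation. The sketch below uses toy sizes and random data (nothing here comes from a real MPT config); it shows why rolling the labels left by one and writing -100 into the wrapped position yields the standard causal language-modeling shift, since `F.cross_entropy` ignores -100 by default.

```
import torch
import torch.nn.functional as F

# Toy sketch of the label shift in MPTForCausalLM.forward: position t is
# supervised by token t+1, and the wrapped-around last position is masked.
vocab_size = 32
logits = torch.randn(2, 5, vocab_size)         # (batch, seq, vocab), random
labels = torch.randint(0, vocab_size, (2, 5))  # random integer targets

labels = torch.roll(labels, shifts=-1)         # flattened roll, mirroring the diff
labels[:, -1] = -100                           # ignore the wrapped positions
loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))
print(loss.item())
```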
model/llava/model/mpt/norm.py CHANGED
@@ -1,28 +1,55 @@
1
  import torch
2
 
 
3
  def _cast_if_autocast_enabled(tensor):
4
  if torch.is_autocast_enabled():
5
- if tensor.device.type == 'cuda':
6
  dtype = torch.get_autocast_gpu_dtype()
7
- elif tensor.device.type == 'cpu':
8
  dtype = torch.get_autocast_cpu_dtype()
9
  else:
10
  raise NotImplementedError()
11
  return tensor.to(dtype=dtype)
12
  return tensor
13
 
14
- class LPLayerNorm(torch.nn.LayerNorm):
15
 
16
- def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
17
- super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def forward(self, x):
20
  module_device = x.device
21
  downcast_x = _cast_if_autocast_enabled(x)
22
- downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
23
- downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
 
 
 
 
 
 
24
  with torch.autocast(enabled=False, device_type=module_device.type):
25
- return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
 
 
 
 
 
 
 
26
 
27
  def rms_norm(x, weight=None, eps=1e-05):
28
  output = x / torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
@@ -30,27 +57,50 @@ def rms_norm(x, weight=None, eps=1e-05):
30
  return output * weight
31
  return output
32
 
33
- class RMSNorm(torch.nn.Module):
34
 
35
- def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
 
 
 
36
  super().__init__()
37
  self.eps = eps
38
  if weight:
39
- self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
 
 
40
  else:
41
- self.register_parameter('weight', None)
42
 
43
  def forward(self, x):
44
  return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
45
 
46
- class LPRMSNorm(RMSNorm):
47
 
48
- def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
49
- super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)
 
 
 
 
 
 
 
 
 
50
 
51
  def forward(self, x):
52
  downcast_x = _cast_if_autocast_enabled(x)
53
- downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
 
 
 
 
54
  with torch.autocast(enabled=False, device_type=x.device.type):
55
  return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
56
- NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
 
 
 
 
 
 
 
 
1
  import torch
2
 
3
+
4
  def _cast_if_autocast_enabled(tensor):
5
  if torch.is_autocast_enabled():
6
+ if tensor.device.type == "cuda":
7
  dtype = torch.get_autocast_gpu_dtype()
8
+ elif tensor.device.type == "cpu":
9
  dtype = torch.get_autocast_cpu_dtype()
10
  else:
11
  raise NotImplementedError()
12
  return tensor.to(dtype=dtype)
13
  return tensor
14
 
 
15
 
16
+ class LPLayerNorm(torch.nn.LayerNorm):
17
+ def __init__(
18
+ self,
19
+ normalized_shape,
20
+ eps=1e-05,
21
+ elementwise_affine=True,
22
+ device=None,
23
+ dtype=None,
24
+ ):
25
+ super().__init__(
26
+ normalized_shape=normalized_shape,
27
+ eps=eps,
28
+ elementwise_affine=elementwise_affine,
29
+ device=device,
30
+ dtype=dtype,
31
+ )
32
 
33
  def forward(self, x):
34
  module_device = x.device
35
  downcast_x = _cast_if_autocast_enabled(x)
36
+ downcast_weight = (
37
+ _cast_if_autocast_enabled(self.weight)
38
+ if self.weight is not None
39
+ else self.weight
40
+ )
41
+ downcast_bias = (
42
+ _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
43
+ )
44
  with torch.autocast(enabled=False, device_type=module_device.type):
45
+ return torch.nn.functional.layer_norm(
46
+ downcast_x,
47
+ self.normalized_shape,
48
+ downcast_weight,
49
+ downcast_bias,
50
+ self.eps,
51
+ )
52
+
53
 
54
  def rms_norm(x, weight=None, eps=1e-05):
55
  output = x / torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
 
57
  return output * weight
58
  return output
59
 
 
60
 
61
+ class RMSNorm(torch.nn.Module):
62
+ def __init__(
63
+ self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None
64
+ ):
65
  super().__init__()
66
  self.eps = eps
67
  if weight:
68
+ self.weight = torch.nn.Parameter(
69
+ torch.ones(normalized_shape, dtype=dtype, device=device)
70
+ )
71
  else:
72
+ self.register_parameter("weight", None)
73
 
74
  def forward(self, x):
75
  return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
76
 
 
77
 
78
+ class LPRMSNorm(RMSNorm):
79
+ def __init__(
80
+ self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None
81
+ ):
82
+ super().__init__(
83
+ normalized_shape=normalized_shape,
84
+ eps=eps,
85
+ weight=weight,
86
+ dtype=dtype,
87
+ device=device,
88
+ )
89
 
90
  def forward(self, x):
91
  downcast_x = _cast_if_autocast_enabled(x)
92
+ downcast_weight = (
93
+ _cast_if_autocast_enabled(self.weight)
94
+ if self.weight is not None
95
+ else self.weight
96
+ )
97
  with torch.autocast(enabled=False, device_type=x.device.type):
98
  return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
99
+
100
+
101
+ NORM_CLASS_REGISTRY = {
102
+ "layernorm": torch.nn.LayerNorm,
103
+ "low_precision_layernorm": LPLayerNorm,
104
+ "rmsnorm": RMSNorm,
105
+ "low_precision_rmsnorm": LPRMSNorm,
106
+ }
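A minimal sketch of how `NORM_CLASS_REGISTRY` is meant to be consumed. The import path assumes the repository root is on `PYTHONPATH`, and `"low_precision_layernorm"` is used as the lookup key only because it is the customary MPT default, not something this diff configures.

```
import torch
from model.llava.model.mpt.norm import NORM_CLASS_REGISTRY  # import path assumed

norm_cls = NORM_CLASS_REGISTRY["low_precision_layernorm"]
norm = norm_cls(512)                  # normalized_shape sized like d_model
x = torch.randn(2, 16, 512)
print(norm(x).shape)                  # torch.Size([2, 16, 512])
```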
model/llava/model/mpt/param_init_fns.py CHANGED
@@ -3,101 +3,139 @@ import warnings
3
  from collections.abc import Sequence
4
  from functools import partial
5
  from typing import Optional, Tuple, Union
 
6
  import torch
7
  from torch import nn
 
8
  from .norm import NORM_CLASS_REGISTRY
9
 
10
- def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs):
 
11
  del kwargs
12
  if verbose > 1:
13
  warnings.warn(f"Initializing network using module's reset_parameters attribute")
14
- if hasattr(module, 'reset_parameters'):
15
  module.reset_parameters()
16
 
 
17
  def fused_init_helper_(module: nn.Module, init_fn_):
18
- _fused = getattr(module, '_fused', None)
19
  if _fused is None:
20
- raise RuntimeError(f'Internal logic error')
21
  (dim, splits) = _fused
22
  splits = (0, *splits, module.weight.size(dim))
23
- for (s, e) in zip(splits[:-1], splits[1:]):
24
  slice_indices = [slice(None)] * module.weight.ndim
25
  slice_indices[dim] = slice(s, e)
26
  init_fn_(module.weight[slice_indices])
27
 
28
- def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
29
  del kwargs
30
  if verbose > 1:
31
- warnings.warn(f'If model has bias parameters they are initialized to 0.')
32
  init_div_is_residual = init_div_is_residual
33
  if init_div_is_residual is False:
34
  div_is_residual = 1.0
35
  elif init_div_is_residual is True:
36
  div_is_residual = math.sqrt(2 * n_layers)
37
- elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
 
 
38
  div_is_residual = init_div_is_residual
39
  elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
40
  div_is_residual = float(init_div_is_residual)
41
  else:
42
  div_is_residual = 1.0
43
- raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
 
 
44
  if init_div_is_residual is not False:
45
  if verbose > 1:
46
- warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' + f'Set `init_div_is_residual: false` in init config to disable this.')
 
 
 
47
  if isinstance(module, nn.Linear):
48
- if hasattr(module, '_fused'):
49
  fused_init_helper_(module, init_fn_)
50
  else:
51
  init_fn_(module.weight)
52
  if module.bias is not None:
53
  torch.nn.init.zeros_(module.bias)
54
- if init_div_is_residual is not False and getattr(module, '_is_residual', False):
55
  with torch.no_grad():
56
  module.weight.div_(div_is_residual)
57
  elif isinstance(module, nn.Embedding):
58
  if emb_init_std is not None:
59
  std = emb_init_std
60
  if std == 0:
61
- warnings.warn(f'Embedding layer initialized to 0.')
62
  emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
63
  if verbose > 1:
64
- warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.')
 
 
65
  elif emb_init_uniform_lim is not None:
66
  lim = emb_init_uniform_lim
67
  if isinstance(lim, Sequence):
68
  if len(lim) > 2:
69
- raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.')
 
 
70
  if lim[0] == lim[1]:
71
- warnings.warn(f'Embedding layer initialized to {lim[0]}.')
72
  else:
73
  if lim == 0:
74
- warnings.warn(f'Embedding layer initialized to 0.')
75
  lim = [-lim, lim]
76
  (a, b) = lim
77
  emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
78
  if verbose > 1:
79
- warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.')
 
 
80
  else:
81
  emb_init_fn_ = init_fn_
82
  emb_init_fn_(module.weight)
83
  elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
84
  if verbose > 1:
85
- warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.')
86
- if hasattr(module, 'weight') and module.weight is not None:
 
 
87
  torch.nn.init.ones_(module.weight)
88
- if hasattr(module, 'bias') and module.bias is not None:
89
  torch.nn.init.zeros_(module.bias)
90
  elif isinstance(module, nn.MultiheadAttention):
91
  if module._qkv_same_embed_dim:
92
  assert module.in_proj_weight is not None
93
- assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
 
 
 
 
94
  assert d_model is not None
95
  _d = d_model
96
  splits = (0, _d, 2 * _d, 3 * _d)
97
- for (s, e) in zip(splits[:-1], splits[1:]):
98
  init_fn_(module.in_proj_weight[s:e])
99
  else:
100
- assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
 
 
 
 
101
  assert module.in_proj_weight is None
102
  init_fn_(module.q_proj_weight)
103
  init_fn_(module.k_proj_weight)
@@ -109,37 +147,112 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
109
  if module.bias_v is not None:
110
  torch.nn.init.zeros_(module.bias_v)
111
  init_fn_(module.out_proj.weight)
112
- if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False):
 
 
113
  with torch.no_grad():
114
  module.out_proj.weight.div_(div_is_residual)
115
  if module.out_proj.bias is not None:
116
  torch.nn.init.zeros_(module.out_proj.bias)
117
  else:
118
  for _ in module.parameters(recurse=False):
119
- raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.')
 
 
 
120
 
121
  def _normal_init_(std, mean=0.0):
122
  return partial(torch.nn.init.normal_, mean=mean, std=std)
123
 
124
- def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
125
  del kwargs
126
  init_fn_ = _normal_init_(std=std)
127
  if verbose > 1:
128
- warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
129
- generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
 
 
 
 
 
 
 
 
 
 
130
 
131
- def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
 
 
 
 
 
 
 
 
 
 
132
  del kwargs
133
  if init_std is None:
134
- raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
135
- _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
 
 
 
 
 
 
 
 
 
 
 
136
 
137
- def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
 
 
 
 
 
 
 
 
 
 
138
  del kwargs
139
  std = math.sqrt(2 / (5 * d_model))
140
- _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
 
 
 
 
 
 
 
 
 
 
141
 
142
- def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
 
 
 
 
 
 
 
 
143
  """From section 2.3.1 of GPT-NeoX-20B:
144
 
145
   An Open-Source Autoregressive Language Model — Black et al. (2022)
@@ -149,33 +262,158 @@ def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init
149
  del kwargs
150
  residual_div = n_layers / math.sqrt(10)
151
  if verbose > 1:
152
- warnings.warn(f'setting init_div_is_residual to {residual_div}')
153
- small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
 
 
 
 
 
 
 
 
154
 
155
- def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  del kwargs
157
  if verbose > 1:
158
- warnings.warn(f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
159
- kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
160
- generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
- def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
163
  del kwargs
164
  if verbose > 1:
165
- warnings.warn(f'Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
166
- kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
167
- generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
- def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
170
  del kwargs
171
  xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
172
  if verbose > 1:
173
- warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}')
174
- generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
- def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
 
 
 
 
 
 
 
 
 
 
177
  xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
178
  if verbose > 1:
179
- warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}')
180
- generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
181
- MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from collections.abc import Sequence
4
  from functools import partial
5
  from typing import Optional, Tuple, Union
6
+
7
  import torch
8
  from torch import nn
9
+
10
  from .norm import NORM_CLASS_REGISTRY
11
 
12
+
13
+ def torch_default_param_init_fn_(module: nn.Module, verbose: int = 0, **kwargs):
14
  del kwargs
15
  if verbose > 1:
16
  warnings.warn(f"Initializing network using module's reset_parameters attribute")
17
+ if hasattr(module, "reset_parameters"):
18
  module.reset_parameters()
19
 
20
+
21
  def fused_init_helper_(module: nn.Module, init_fn_):
22
+ _fused = getattr(module, "_fused", None)
23
  if _fused is None:
24
+ raise RuntimeError(f"Internal logic error")
25
  (dim, splits) = _fused
26
  splits = (0, *splits, module.weight.size(dim))
27
+ for s, e in zip(splits[:-1], splits[1:]):
28
  slice_indices = [slice(None)] * module.weight.ndim
29
  slice_indices[dim] = slice(s, e)
30
  init_fn_(module.weight[slice_indices])
31
 
32
+
33
+ def generic_param_init_fn_(
34
+ module: nn.Module,
35
+ init_fn_,
36
+ n_layers: int,
37
+ d_model: Optional[int] = None,
38
+ init_div_is_residual: Union[int, float, str, bool] = True,
39
+ emb_init_std: Optional[float] = None,
40
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
41
+ verbose: int = 0,
42
+ **kwargs,
43
+ ):
44
  del kwargs
45
  if verbose > 1:
46
+ warnings.warn(f"If model has bias parameters they are initialized to 0.")
47
  init_div_is_residual = init_div_is_residual
48
  if init_div_is_residual is False:
49
  div_is_residual = 1.0
50
  elif init_div_is_residual is True:
51
  div_is_residual = math.sqrt(2 * n_layers)
52
+ elif isinstance(init_div_is_residual, float) or isinstance(
53
+ init_div_is_residual, int
54
+ ):
55
  div_is_residual = init_div_is_residual
56
  elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
57
  div_is_residual = float(init_div_is_residual)
58
  else:
59
  div_is_residual = 1.0
60
+ raise ValueError(
61
+ f"Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}"
62
+ )
63
  if init_div_is_residual is not False:
64
  if verbose > 1:
65
+ warnings.warn(
66
+ f"Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. "
67
+ + f"Set `init_div_is_residual: false` in init config to disable this."
68
+ )
69
  if isinstance(module, nn.Linear):
70
+ if hasattr(module, "_fused"):
71
  fused_init_helper_(module, init_fn_)
72
  else:
73
  init_fn_(module.weight)
74
  if module.bias is not None:
75
  torch.nn.init.zeros_(module.bias)
76
+ if init_div_is_residual is not False and getattr(module, "_is_residual", False):
77
  with torch.no_grad():
78
  module.weight.div_(div_is_residual)
79
  elif isinstance(module, nn.Embedding):
80
  if emb_init_std is not None:
81
  std = emb_init_std
82
  if std == 0:
83
+ warnings.warn(f"Embedding layer initialized to 0.")
84
  emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
85
  if verbose > 1:
86
+ warnings.warn(
87
+ f"Embedding layer initialized using normal distribution with mean=0 and std={std!r}."
88
+ )
89
  elif emb_init_uniform_lim is not None:
90
  lim = emb_init_uniform_lim
91
  if isinstance(lim, Sequence):
92
  if len(lim) > 2:
93
+ raise ValueError(
94
+ f"Uniform init requires a min and a max limit. User input: {lim}."
95
+ )
96
  if lim[0] == lim[1]:
97
+ warnings.warn(f"Embedding layer initialized to {lim[0]}.")
98
  else:
99
  if lim == 0:
100
+ warnings.warn(f"Embedding layer initialized to 0.")
101
  lim = [-lim, lim]
102
  (a, b) = lim
103
  emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
104
  if verbose > 1:
105
+ warnings.warn(
106
+ f"Embedding layer initialized using uniform distribution in range {lim}."
107
+ )
108
  else:
109
  emb_init_fn_ = init_fn_
110
  emb_init_fn_(module.weight)
111
  elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
112
  if verbose > 1:
113
+ warnings.warn(
114
+ f"Norm weights are set to 1. If norm layer has a bias it is initialized to 0."
115
+ )
116
+ if hasattr(module, "weight") and module.weight is not None:
117
  torch.nn.init.ones_(module.weight)
118
+ if hasattr(module, "bias") and module.bias is not None:
119
  torch.nn.init.zeros_(module.bias)
120
  elif isinstance(module, nn.MultiheadAttention):
121
  if module._qkv_same_embed_dim:
122
  assert module.in_proj_weight is not None
123
+ assert (
124
+ module.q_proj_weight is None
125
+ and module.k_proj_weight is None
126
+ and (module.v_proj_weight is None)
127
+ )
128
  assert d_model is not None
129
  _d = d_model
130
  splits = (0, _d, 2 * _d, 3 * _d)
131
+ for s, e in zip(splits[:-1], splits[1:]):
132
  init_fn_(module.in_proj_weight[s:e])
133
  else:
134
+ assert (
135
+ module.q_proj_weight is not None
136
+ and module.k_proj_weight is not None
137
+ and (module.v_proj_weight is not None)
138
+ )
139
  assert module.in_proj_weight is None
140
  init_fn_(module.q_proj_weight)
141
  init_fn_(module.k_proj_weight)
 
147
  if module.bias_v is not None:
148
  torch.nn.init.zeros_(module.bias_v)
149
  init_fn_(module.out_proj.weight)
150
+ if init_div_is_residual is not False and getattr(
151
+ module.out_proj, "_is_residual", False
152
+ ):
153
  with torch.no_grad():
154
  module.out_proj.weight.div_(div_is_residual)
155
  if module.out_proj.bias is not None:
156
  torch.nn.init.zeros_(module.out_proj.bias)
157
  else:
158
  for _ in module.parameters(recurse=False):
159
+ raise NotImplementedError(
160
+ f"{module.__class__.__name__} parameters are not initialized by param_init_fn."
161
+ )
162
+
163
 
164
  def _normal_init_(std, mean=0.0):
165
  return partial(torch.nn.init.normal_, mean=mean, std=std)
166
 
167
+
168
+ def _normal_param_init_fn_(
169
+ module: nn.Module,
170
+ std: float,
171
+ n_layers: int,
172
+ d_model: Optional[int] = None,
173
+ init_div_is_residual: Union[int, float, str, bool] = True,
174
+ emb_init_std: Optional[float] = None,
175
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
176
+ verbose: int = 0,
177
+ **kwargs,
178
+ ):
179
  del kwargs
180
  init_fn_ = _normal_init_(std=std)
181
  if verbose > 1:
182
+ warnings.warn(f"Using torch.nn.init.normal_ init fn mean=0.0, std={std}")
183
+ generic_param_init_fn_(
184
+ module=module,
185
+ init_fn_=init_fn_,
186
+ d_model=d_model,
187
+ n_layers=n_layers,
188
+ init_div_is_residual=init_div_is_residual,
189
+ emb_init_std=emb_init_std,
190
+ emb_init_uniform_lim=emb_init_uniform_lim,
191
+ verbose=verbose,
192
+ )
193
+
194
 
195
+ def baseline_param_init_fn_(
196
+ module: nn.Module,
197
+ init_std: float,
198
+ n_layers: int,
199
+ d_model: Optional[int] = None,
200
+ init_div_is_residual: Union[int, float, str, bool] = True,
201
+ emb_init_std: Optional[float] = None,
202
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
203
+ verbose: int = 0,
204
+ **kwargs,
205
+ ):
206
  del kwargs
207
  if init_std is None:
208
+ raise ValueError(
209
+ "You must set model.init_config['init_std'] to a float value to use the default initialization scheme."
210
+ )
211
+ _normal_param_init_fn_(
212
+ module=module,
213
+ std=init_std,
214
+ d_model=d_model,
215
+ n_layers=n_layers,
216
+ init_div_is_residual=init_div_is_residual,
217
+ emb_init_std=emb_init_std,
218
+ emb_init_uniform_lim=emb_init_uniform_lim,
219
+ verbose=verbose,
220
+ )
221
 
222
+
223
+ def small_param_init_fn_(
224
+ module: nn.Module,
225
+ n_layers: int,
226
+ d_model: int,
227
+ init_div_is_residual: Union[int, float, str, bool] = True,
228
+ emb_init_std: Optional[float] = None,
229
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
230
+ verbose: int = 0,
231
+ **kwargs,
232
+ ):
233
  del kwargs
234
  std = math.sqrt(2 / (5 * d_model))
235
+ _normal_param_init_fn_(
236
+ module=module,
237
+ std=std,
238
+ d_model=d_model,
239
+ n_layers=n_layers,
240
+ init_div_is_residual=init_div_is_residual,
241
+ emb_init_std=emb_init_std,
242
+ emb_init_uniform_lim=emb_init_uniform_lim,
243
+ verbose=verbose,
244
+ )
245
+
246
 
247
+ def neox_param_init_fn_(
248
+ module: nn.Module,
249
+ n_layers: int,
250
+ d_model: int,
251
+ emb_init_std: Optional[float] = None,
252
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
253
+ verbose: int = 0,
254
+ **kwargs,
255
+ ):
256
  """From section 2.3.1 of GPT-NeoX-20B:
257
 
258
   An Open-Source Autoregressive Language Model — Black et al. (2022)
 
262
  del kwargs
263
  residual_div = n_layers / math.sqrt(10)
264
  if verbose > 1:
265
+ warnings.warn(f"setting init_div_is_residual to {residual_div}")
266
+ small_param_init_fn_(
267
+ module=module,
268
+ d_model=d_model,
269
+ n_layers=n_layers,
270
+ init_div_is_residual=residual_div,
271
+ emb_init_std=emb_init_std,
272
+ emb_init_uniform_lim=emb_init_uniform_lim,
273
+ verbose=verbose,
274
+ )
275
 
276
+
277
+ def kaiming_uniform_param_init_fn_(
278
+ module: nn.Module,
279
+ n_layers: int,
280
+ d_model: Optional[int] = None,
281
+ init_div_is_residual: Union[int, float, str, bool] = True,
282
+ emb_init_std: Optional[float] = None,
283
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
284
+ init_gain: float = 0,
285
+ fan_mode: str = "fan_in",
286
+ init_nonlinearity: str = "leaky_relu",
287
+ verbose: int = 0,
288
+ **kwargs,
289
+ ):
290
  del kwargs
291
  if verbose > 1:
292
+ warnings.warn(
293
+ f"Using nn.init.kaiming_uniform_ init fn with parameters: "
294
+ + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}"
295
+ )
296
+ kaiming_uniform_ = partial(
297
+ nn.init.kaiming_uniform_,
298
+ a=init_gain,
299
+ mode=fan_mode,
300
+ nonlinearity=init_nonlinearity,
301
+ )
302
+ generic_param_init_fn_(
303
+ module=module,
304
+ init_fn_=kaiming_uniform_,
305
+ d_model=d_model,
306
+ n_layers=n_layers,
307
+ init_div_is_residual=init_div_is_residual,
308
+ emb_init_std=emb_init_std,
309
+ emb_init_uniform_lim=emb_init_uniform_lim,
310
+ verbose=verbose,
311
+ )
312
+
313
 
314
+ def kaiming_normal_param_init_fn_(
315
+ module: nn.Module,
316
+ n_layers: int,
317
+ d_model: Optional[int] = None,
318
+ init_div_is_residual: Union[int, float, str, bool] = True,
319
+ emb_init_std: Optional[float] = None,
320
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
321
+ init_gain: float = 0,
322
+ fan_mode: str = "fan_in",
323
+ init_nonlinearity: str = "leaky_relu",
324
+ verbose: int = 0,
325
+ **kwargs,
326
+ ):
327
  del kwargs
328
  if verbose > 1:
329
+ warnings.warn(
330
+ f"Using nn.init.kaiming_normal_ init fn with parameters: "
331
+ + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}"
332
+ )
333
+ kaiming_normal_ = partial(
334
+ torch.nn.init.kaiming_normal_,
335
+ a=init_gain,
336
+ mode=fan_mode,
337
+ nonlinearity=init_nonlinearity,
338
+ )
339
+ generic_param_init_fn_(
340
+ module=module,
341
+ init_fn_=kaiming_normal_,
342
+ d_model=d_model,
343
+ n_layers=n_layers,
344
+ init_div_is_residual=init_div_is_residual,
345
+ emb_init_std=emb_init_std,
346
+ emb_init_uniform_lim=emb_init_uniform_lim,
347
+ verbose=verbose,
348
+ )
349
 
350
+
351
+ def xavier_uniform_param_init_fn_(
352
+ module: nn.Module,
353
+ n_layers: int,
354
+ d_model: Optional[int] = None,
355
+ init_div_is_residual: Union[int, float, str, bool] = True,
356
+ emb_init_std: Optional[float] = None,
357
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
358
+ init_gain: float = 0,
359
+ verbose: int = 0,
360
+ **kwargs,
361
+ ):
362
  del kwargs
363
  xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
364
  if verbose > 1:
365
+ warnings.warn(
366
+ f"Using torch.nn.init.xavier_uniform_ init fn with parameters: "
367
+ + f"gain={init_gain}"
368
+ )
369
+ generic_param_init_fn_(
370
+ module=module,
371
+ init_fn_=xavier_uniform_,
372
+ d_model=d_model,
373
+ n_layers=n_layers,
374
+ init_div_is_residual=init_div_is_residual,
375
+ emb_init_std=emb_init_std,
376
+ emb_init_uniform_lim=emb_init_uniform_lim,
377
+ verbose=verbose,
378
+ )
379
+
380
 
381
+ def xavier_normal_param_init_fn_(
382
+ module: nn.Module,
383
+ n_layers: int,
384
+ d_model: Optional[int] = None,
385
+ init_div_is_residual: Union[int, float, str, bool] = True,
386
+ emb_init_std: Optional[float] = None,
387
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
388
+ init_gain: float = 0,
389
+ verbose: int = 0,
390
+ **kwargs,
391
+ ):
392
  xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
393
  if verbose > 1:
394
+ warnings.warn(
395
+ f"Using torch.nn.init.xavier_normal_ init fn with parameters: "
396
+ + f"gain={init_gain}"
397
+ )
398
+ generic_param_init_fn_(
399
+ module=module,
400
+ init_fn_=xavier_normal_,
401
+ d_model=d_model,
402
+ n_layers=n_layers,
403
+ init_div_is_residual=init_div_is_residual,
404
+ emb_init_std=emb_init_std,
405
+ emb_init_uniform_lim=emb_init_uniform_lim,
406
+ verbose=verbose,
407
+ )
408
+
409
+
410
+ MODEL_INIT_REGISTRY = {
411
+ "default_": torch_default_param_init_fn_,
412
+ "baseline_": baseline_param_init_fn_,
413
+ "kaiming_uniform_": kaiming_uniform_param_init_fn_,
414
+ "kaiming_normal_": kaiming_normal_param_init_fn_,
415
+ "neox_init_": neox_param_init_fn_,
416
+ "small_init_": small_param_init_fn_,
417
+ "xavier_uniform_": xavier_uniform_param_init_fn_,
418
+ "xavier_normal_": xavier_normal_param_init_fn_,
419
+ }
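`MODEL_INIT_REGISTRY` at the end of this file is what `param_init_fn` in `modeling_mpt.py` (earlier in this diff) dispatches through. A hedged sketch of that call pattern follows; the `init_config` values, layer count, and width are illustrative rather than taken from a real checkpoint.

```
import torch.nn as nn
from model.llava.model.mpt.param_init_fns import MODEL_INIT_REGISTRY  # path assumed

# "name" selects the registry entry; remaining keys are forwarded as keyword
# arguments, and unknown keys fall into **kwargs and are discarded.
init_config = {"name": "kaiming_normal_", "init_nonlinearity": "relu"}

module = nn.Linear(4096, 4096)
MODEL_INIT_REGISTRY[init_config["name"]](
    module=module,
    n_layers=32,
    d_model=4096,
    **init_config,
)
# The Linear weight is now Kaiming-normal initialized and its bias zeroed.
```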
model/llava/model/utils.py CHANGED
@@ -5,16 +5,20 @@ from transformers import AutoConfig, StoppingCriteria
5
 
6
  def auto_upgrade(config):
7
  cfg = AutoConfig.from_pretrained(config)
8
- if 'llava' in config and 'llava' not in cfg.model_type:
9
- assert cfg.model_type == 'llama'
10
- print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
11
- print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
 
 
 
 
12
  confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
13
  if confirm.lower() in ["y", "yes"]:
14
  print("Upgrading checkpoint...")
15
  assert len(cfg.architectures) == 1
16
  setattr(cfg.__class__, "model_type", "llava")
17
- cfg.architectures[0] = 'LlavaLlamaForCausalLM'
18
  cfg.save_pretrained(config)
19
  print("Checkpoint upgraded.")
20
  else:
@@ -22,24 +26,31 @@ def auto_upgrade(config):
22
  exit(1)
23
 
24
 
25
-
26
  class KeywordsStoppingCriteria(StoppingCriteria):
27
  def __init__(self, keywords, tokenizer, input_ids):
28
  self.keywords = keywords
29
  self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
30
- self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]
 
 
 
 
31
  self.tokenizer = tokenizer
32
  self.start_len = None
33
  self.input_ids = input_ids
34
 
35
- def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
 
 
36
  if self.start_len is None:
37
  self.start_len = self.input_ids.shape[1]
38
  else:
39
  for keyword_id in self.keyword_ids:
40
  if output_ids[0, -1] == keyword_id:
41
  return True
42
- outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
 
 
43
  for keyword in self.keywords:
44
  if keyword in outputs:
45
  return True
@@ -50,7 +61,7 @@ class KeywordsStoppingCriteria(StoppingCriteria):
50
 
51
  # # if output_ids[0, -1] == keyword_id:
52
  # # return True
53
-
54
  # print("output_ids.shape: {}, self.start_len: {}".format(output_ids.shape, self.start_len))
55
 
56
  # print("output_ids[:, self.start_len:]: ", output_ids[:, self.start_len:])
 
5
 
6
  def auto_upgrade(config):
7
  cfg = AutoConfig.from_pretrained(config)
8
+ if "llava" in config and "llava" not in cfg.model_type:
9
+ assert cfg.model_type == "llama"
10
+ print(
11
+ "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base."
12
+ )
13
+ print(
14
+ "You must upgrade the checkpoint to the new code base (this can be done automatically)."
15
+ )
16
  confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
17
  if confirm.lower() in ["y", "yes"]:
18
  print("Upgrading checkpoint...")
19
  assert len(cfg.architectures) == 1
20
  setattr(cfg.__class__, "model_type", "llava")
21
+ cfg.architectures[0] = "LlavaLlamaForCausalLM"
22
  cfg.save_pretrained(config)
23
  print("Checkpoint upgraded.")
24
  else:
 
26
  exit(1)
27
 
28
 
 
29
  class KeywordsStoppingCriteria(StoppingCriteria):
30
  def __init__(self, keywords, tokenizer, input_ids):
31
  self.keywords = keywords
32
  self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
33
+ self.keyword_ids = [
34
+ keyword_id[0]
35
+ for keyword_id in self.keyword_ids
36
+ if type(keyword_id) is list and len(keyword_id) == 1
37
+ ]
38
  self.tokenizer = tokenizer
39
  self.start_len = None
40
  self.input_ids = input_ids
41
 
42
+ def __call__(
43
+ self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
44
+ ) -> bool:
45
  if self.start_len is None:
46
  self.start_len = self.input_ids.shape[1]
47
  else:
48
  for keyword_id in self.keyword_ids:
49
  if output_ids[0, -1] == keyword_id:
50
  return True
51
+ outputs = self.tokenizer.batch_decode(
52
+ output_ids[:, self.start_len :], skip_special_tokens=True
53
+ )[0]
54
  for keyword in self.keywords:
55
  if keyword in outputs:
56
  return True
 
61
 
62
  # # if output_ids[0, -1] == keyword_id:
63
  # # return True
64
+
65
  # print("output_ids.shape: {}, self.start_len: {}".format(output_ids.shape, self.start_len))
66
 
67
  # print("output_ids[:, self.start_len:]: ", output_ids[:, self.start_len:])
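A sketch of how `KeywordsStoppingCriteria` plugs into `generate()`. The import path, the stop string, and the tiny public model are placeholders for illustration; they are not what the repository's chat script actually loads.

```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
from model.llava.model.utils import KeywordsStoppingCriteria  # import path assumed

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

input_ids = tokenizer("Hello", return_tensors="pt").input_ids
criteria = KeywordsStoppingCriteria(["###"], tokenizer, input_ids)  # stop at "###"
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        max_new_tokens=32,
        stopping_criteria=StoppingCriteriaList([criteria]),
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```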
model/llava/serve/cli.py CHANGED
@@ -6,14 +6,14 @@ import argparse
6
  import time
7
 
8
  import torch
9
- from transformers import AutoTokenizer, AutoModelForCausalLM
10
-
11
- from llava.conversation import conv_templates, SeparatorStyle
12
 
13
 
14
  @torch.inference_mode()
15
- def generate_stream(tokenizer, model, params, device,
16
- context_len=2048, stream_interval=2):
 
17
  """Adapted from fastchat/serve/model_worker.py::generate_stream"""
18
 
19
  prompt = params["prompt"]
@@ -30,17 +30,19 @@ def generate_stream(tokenizer, model, params, device,
30
 
31
  for i in range(max_new_tokens):
32
  if i == 0:
33
- out = model(
34
- torch.as_tensor([input_ids], device=device), use_cache=True)
35
  logits = out.logits
36
  past_key_values = out.past_key_values
37
  else:
38
  attention_mask = torch.ones(
39
- 1, past_key_values[0][0].shape[-2] + 1, device=device)
40
- out = model(input_ids=torch.as_tensor([[token]], device=device),
41
- use_cache=True,
42
- attention_mask=attention_mask,
43
- past_key_values=past_key_values)
 
 
 
44
  logits = out.logits
45
  past_key_values = out.past_key_values
46
 
@@ -84,18 +86,21 @@ def main(args):
84
  else:
85
  num_gpus = int(num_gpus)
86
  if num_gpus != 1:
87
- kwargs.update({
88
- "device_map": "auto",
89
- "max_memory": {i: "13GiB" for i in range(num_gpus)},
90
- })
 
 
91
  elif args.device == "cpu":
92
  kwargs = {}
93
  else:
94
  raise ValueError(f"Invalid device: {args.device}")
95
 
96
  tokenizer = AutoTokenizer.from_pretrained(model_name)
97
- model = AutoModelForCausalLM.from_pretrained(model_name,
98
- low_cpu_mem_usage=True, **kwargs)
 
99
 
100
  if args.device == "cuda" and num_gpus == 1:
101
  model.cuda()
@@ -126,11 +131,11 @@ def main(args):
126
  print(f"{conv.roles[1]}: ", end="", flush=True)
127
  pre = 0
128
  for outputs in generate_stream(tokenizer, model, params, args.device):
129
- outputs = outputs[len(prompt) + 1:].strip()
130
  outputs = outputs.split(" ")
131
  now = len(outputs)
132
  if now - 1 > pre:
133
- print(" ".join(outputs[pre:now-1]), end=" ", flush=True)
134
  pre = now - 1
135
  print(" ".join(outputs[pre:]), flush=True)
136
 
 
6
  import time
7
 
8
  import torch
9
+ from llava.conversation import SeparatorStyle, conv_templates
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
11
 
12
 
13
  @torch.inference_mode()
14
+ def generate_stream(
15
+ tokenizer, model, params, device, context_len=2048, stream_interval=2
16
+ ):
17
  """Adapted from fastchat/serve/model_worker.py::generate_stream"""
18
 
19
  prompt = params["prompt"]
 
30
 
31
  for i in range(max_new_tokens):
32
  if i == 0:
33
+ out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
 
34
  logits = out.logits
35
  past_key_values = out.past_key_values
36
  else:
37
  attention_mask = torch.ones(
38
+ 1, past_key_values[0][0].shape[-2] + 1, device=device
39
+ )
40
+ out = model(
41
+ input_ids=torch.as_tensor([[token]], device=device),
42
+ use_cache=True,
43
+ attention_mask=attention_mask,
44
+ past_key_values=past_key_values,
45
+ )
46
  logits = out.logits
47
  past_key_values = out.past_key_values
48
 
 
86
  else:
87
  num_gpus = int(num_gpus)
88
  if num_gpus != 1:
89
+ kwargs.update(
90
+ {
91
+ "device_map": "auto",
92
+ "max_memory": {i: "13GiB" for i in range(num_gpus)},
93
+ }
94
+ )
95
  elif args.device == "cpu":
96
  kwargs = {}
97
  else:
98
  raise ValueError(f"Invalid device: {args.device}")
99
 
100
  tokenizer = AutoTokenizer.from_pretrained(model_name)
101
+ model = AutoModelForCausalLM.from_pretrained(
102
+ model_name, low_cpu_mem_usage=True, **kwargs
103
+ )
104
 
105
  if args.device == "cuda" and num_gpus == 1:
106
  model.cuda()
 
131
  print(f"{conv.roles[1]}: ", end="", flush=True)
132
  pre = 0
133
  for outputs in generate_stream(tokenizer, model, params, args.device):
134
+ outputs = outputs[len(prompt) + 1 :].strip()
135
  outputs = outputs.split(" ")
136
  now = len(outputs)
137
  if now - 1 > pre:
138
+ print(" ".join(outputs[pre : now - 1]), end=" ", flush=True)
139
  pre = now - 1
140
  print(" ".join(outputs[pre:]), flush=True)
141
 
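The loop in `generate_stream` above amortizes the prompt with a single prefill and then feeds one token per step while reusing `past_key_values`. Below is a hedged, greedy reduction of that pattern; the tiny public model is only a stand-in and the attention mask is omitted for brevity.

```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2").eval()

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
with torch.inference_mode():
    out = model(input_ids, use_cache=True)        # prefill over the whole prompt
    past_key_values = out.past_key_values
    token = int(torch.argmax(out.logits[0, -1]))
    generated = [token]
    for _ in range(8):                            # then decode one token per step
        out = model(
            input_ids=torch.as_tensor([[token]]),
            use_cache=True,
            past_key_values=past_key_values,
        )
        past_key_values = out.past_key_values
        token = int(torch.argmax(out.logits[0, -1]))
        generated.append(token)
print(tokenizer.decode(generated))
```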
model/llava/serve/controller.py CHANGED
@@ -5,23 +5,21 @@ It sends worker addresses to clients.
5
  import argparse
6
  import asyncio
7
  import dataclasses
8
- from enum import Enum, auto
9
  import json
10
  import logging
 
11
  import time
 
12
  from typing import List, Union
13
- import threading
14
 
15
- from fastapi import FastAPI, Request
16
- from fastapi.responses import StreamingResponse
17
  import numpy as np
18
  import requests
19
  import uvicorn
20
-
 
21
  from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
22
  from llava.utils import build_logger, server_error_msg
23
 
24
-
25
  logger = build_logger("controller", "controller.log")
26
 
27
 
@@ -61,13 +59,15 @@ class Controller:
61
  self.dispatch_method = DispatchMethod.from_str(dispatch_method)
62
 
63
  self.heart_beat_thread = threading.Thread(
64
- target=heart_beat_controller, args=(self,))
 
65
  self.heart_beat_thread.start()
66
 
67
  logger.info("Init controller")
68
 
69
- def register_worker(self, worker_name: str, check_heart_beat: bool,
70
- worker_status: dict):
 
71
  if worker_name not in self.worker_info:
72
  logger.info(f"Register a new worker: {worker_name}")
73
  else:
@@ -79,8 +79,12 @@ class Controller:
79
  return False
80
 
81
  self.worker_info[worker_name] = WorkerInfo(
82
- worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
83
- check_heart_beat, time.time())
 
 
 
 
84
 
85
  logger.info(f"Register done: {worker_name}, {worker_status}")
86
  return True
@@ -131,15 +135,13 @@ class Controller:
131
  return ""
132
  worker_speeds = worker_speeds / norm
133
  if True: # Directly return address
134
- pt = np.random.choice(np.arange(len(worker_names)),
135
- p=worker_speeds)
136
  worker_name = worker_names[pt]
137
  return worker_name
138
 
139
  # Check status before returning
140
  while True:
141
- pt = np.random.choice(np.arange(len(worker_names)),
142
- p=worker_speeds)
143
  worker_name = worker_names[pt]
144
 
145
  if self.get_worker_status(worker_name):
@@ -165,7 +167,9 @@ class Controller:
165
  min_index = np.argmin(worker_qlen)
166
  w_name = worker_names[min_index]
167
  self.worker_info[w_name].queue_length += 1
168
- logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
 
 
169
  return w_name
170
  else:
171
  raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
@@ -201,8 +205,12 @@ class Controller:
201
  yield json.dumps(ret).encode() + b"\0"
202
 
203
  try:
204
- response = requests.post(worker_addr + "/worker_generate_stream",
205
- json=params, stream=True, timeout=5)
 
 
 
 
206
  for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
207
  if chunk:
208
  yield chunk + b"\0"
@@ -214,7 +222,6 @@ class Controller:
214
  }
215
  yield json.dumps(ret).encode() + b"\0"
216
 
217
-
218
  # Let the controller act as a worker to achieve hierarchical
219
  # management. This can be used to connect isolated sub networks.
220
  def worker_api_get_status(self):
@@ -243,8 +250,8 @@ app = FastAPI()
243
  async def register_worker(request: Request):
244
  data = await request.json()
245
  controller.register_worker(
246
- data["worker_name"], data["check_heart_beat"],
247
- data.get("worker_status", None))
248
 
249
 
250
  @app.post("/refresh_all_workers")
@@ -268,8 +275,7 @@ async def get_worker_address(request: Request):
268
  @app.post("/receive_heart_beat")
269
  async def receive_heart_beat(request: Request):
270
  data = await request.json()
271
- exist = controller.receive_heart_beat(
272
- data["worker_name"], data["queue_length"])
273
  return {"exist": exist}
274
 
275
 
@@ -289,8 +295,12 @@ if __name__ == "__main__":
289
  parser = argparse.ArgumentParser()
290
  parser.add_argument("--host", type=str, default="localhost")
291
  parser.add_argument("--port", type=int, default=21001)
292
- parser.add_argument("--dispatch-method", type=str, choices=[
293
- "lottery", "shortest_queue"], default="shortest_queue")
 
 
 
 
294
  args = parser.parse_args()
295
  logger.info(f"args: {args}")
296
 
 
5
  import argparse
6
  import asyncio
7
  import dataclasses
 
8
  import json
9
  import logging
10
+ import threading
11
  import time
12
+ from enum import Enum, auto
13
  from typing import List, Union
 
14
 
 
 
15
  import numpy as np
16
  import requests
17
  import uvicorn
18
+ from fastapi import FastAPI, Request
19
+ from fastapi.responses import StreamingResponse
20
  from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
21
  from llava.utils import build_logger, server_error_msg
22
 
 
23
  logger = build_logger("controller", "controller.log")
24
 
25
 
 
59
  self.dispatch_method = DispatchMethod.from_str(dispatch_method)
60
 
61
  self.heart_beat_thread = threading.Thread(
62
+ target=heart_beat_controller, args=(self,)
63
+ )
64
  self.heart_beat_thread.start()
65
 
66
  logger.info("Init controller")
67
 
68
+ def register_worker(
69
+ self, worker_name: str, check_heart_beat: bool, worker_status: dict
70
+ ):
71
  if worker_name not in self.worker_info:
72
  logger.info(f"Register a new worker: {worker_name}")
73
  else:
 
79
  return False
80
 
81
  self.worker_info[worker_name] = WorkerInfo(
82
+ worker_status["model_names"],
83
+ worker_status["speed"],
84
+ worker_status["queue_length"],
85
+ check_heart_beat,
86
+ time.time(),
87
+ )
88
 
89
  logger.info(f"Register done: {worker_name}, {worker_status}")
90
  return True
 
135
  return ""
136
  worker_speeds = worker_speeds / norm
137
  if True: # Directly return address
138
+ pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
 
139
  worker_name = worker_names[pt]
140
  return worker_name
141
 
142
  # Check status before returning
143
  while True:
144
+ pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
 
145
  worker_name = worker_names[pt]
146
 
147
  if self.get_worker_status(worker_name):
 
167
  min_index = np.argmin(worker_qlen)
168
  w_name = worker_names[min_index]
169
  self.worker_info[w_name].queue_length += 1
170
+ logger.info(
171
+ f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}"
172
+ )
173
  return w_name
174
  else:
175
  raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
 
205
  yield json.dumps(ret).encode() + b"\0"
206
 
207
  try:
208
+ response = requests.post(
209
+ worker_addr + "/worker_generate_stream",
210
+ json=params,
211
+ stream=True,
212
+ timeout=5,
213
+ )
214
  for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
215
  if chunk:
216
  yield chunk + b"\0"
 
222
  }
223
  yield json.dumps(ret).encode() + b"\0"
224
 
 
225
  # Let the controller act as a worker to achieve hierarchical
226
  # management. This can be used to connect isolated sub networks.
227
  def worker_api_get_status(self):
 
250
  async def register_worker(request: Request):
251
  data = await request.json()
252
  controller.register_worker(
253
+ data["worker_name"], data["check_heart_beat"], data.get("worker_status", None)
254
+ )
255
 
256
 
257
  @app.post("/refresh_all_workers")
 
275
  @app.post("/receive_heart_beat")
276
  async def receive_heart_beat(request: Request):
277
  data = await request.json()
278
+ exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"])
 
279
  return {"exist": exist}
280
 
281
 
 
295
  parser = argparse.ArgumentParser()
296
  parser.add_argument("--host", type=str, default="localhost")
297
  parser.add_argument("--port", type=int, default=21001)
298
+ parser.add_argument(
299
+ "--dispatch-method",
300
+ type=str,
301
+ choices=["lottery", "shortest_queue"],
302
+ default="shortest_queue",
303
+ )
304
  args = parser.parse_args()
305
  logger.info(f"args: {args}")
306
 
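A toy sketch of the two dispatch policies wired up above ("lottery" and "shortest_queue"), using made-up worker statistics in place of live heartbeat data.

```
import numpy as np

worker_names = ["http://worker-a:40000", "http://worker-b:40001"]  # made-up addresses
worker_speeds = np.array([1.0, 3.0], dtype=np.float32)             # made-up speeds
worker_qlen = [4, 1]                                               # made-up queue lengths

# "lottery": sample a worker with probability proportional to its speed.
p = worker_speeds / worker_speeds.sum()
lottery_pick = worker_names[np.random.choice(np.arange(len(worker_names)), p=p)]

# "shortest_queue": pick the worker with the fewest queued requests.
shortest_pick = worker_names[int(np.argmin(worker_qlen))]

print(lottery_pick, shortest_pick)
```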
model/llava/serve/gradio_css.py CHANGED
@@ -1,5 +1,4 @@
1
- code_highlight_css = (
2
- """
3
  #chatbot .hll { background-color: #ffffcc }
4
  #chatbot .c { color: #408080; font-style: italic }
5
  #chatbot .err { border: 1px solid #FF0000 }
@@ -68,6 +67,5 @@ code_highlight_css = (
68
  #chatbot .vi { color: #19177C }
69
  #chatbot .vm { color: #19177C }
70
  #chatbot .il { color: #666666 }
71
- """)
72
- #.highlight { background: #f8f8f8; }
73
-
 
1
+ code_highlight_css = """
 
2
  #chatbot .hll { background-color: #ffffcc }
3
  #chatbot .c { color: #408080; font-style: italic }
4
  #chatbot .err { border: 1px solid #FF0000 }
 
67
  #chatbot .vi { color: #19177C }
68
  #chatbot .vm { color: #19177C }
69
  #chatbot .il { color: #666666 }
70
+ """
71
+ # .highlight { background: #f8f8f8; }
 
model/llava/serve/gradio_patch.py CHANGED
@@ -50,7 +50,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
50
  warnings.warn(
51
  "The 'color_map' parameter has been deprecated.",
52
  )
53
- #self.md = utils.get_markdown_parser()
54
  self.md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"])
55
  self.select: EventListenerMethod
56
  """
@@ -113,7 +113,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
113
  ): # This happens for previously processed messages
114
  return chat_message
115
  elif isinstance(chat_message, str):
116
- #return self.md.render(chat_message)
117
  return str(self.md.convert(chat_message))
118
  else:
119
  raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
@@ -142,9 +142,10 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
142
  ), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
143
  processed_messages.append(
144
  (
145
- #self._process_chat_messages(message_pair[0]),
146
- '<pre style="font-family: var(--font)">' +
147
- message_pair[0] + "</pre>",
 
148
  self._process_chat_messages(message_pair[1]),
149
  )
150
  )
@@ -164,5 +165,3 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
164
  **kwargs,
165
  )
166
  return self
167
-
168
-
 
50
  warnings.warn(
51
  "The 'color_map' parameter has been deprecated.",
52
  )
53
+ # self.md = utils.get_markdown_parser()
54
  self.md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"])
55
  self.select: EventListenerMethod
56
  """
 
113
  ): # This happens for previously processed messages
114
  return chat_message
115
  elif isinstance(chat_message, str):
116
+ # return self.md.render(chat_message)
117
  return str(self.md.convert(chat_message))
118
  else:
119
  raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
 
142
  ), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
143
  processed_messages.append(
144
  (
145
+ # self._process_chat_messages(message_pair[0]),
146
+ '<pre style="font-family: var(--font)">'
147
+ + message_pair[0]
148
+ + "</pre>",
149
  self._process_chat_messages(message_pair[1]),
150
  )
151
  )
 
165
  **kwargs,
166
  )
167
  return self
 
 
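The patched Chatbot renders the two sides of a turn differently: the user message is wrapped verbatim in a `<pre>` block, while the model message goes through markdown2 with the extras configured above. A small sketch with made-up messages:

```
from markdown2 import Markdown

md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"])
user_msg = "Describe this image"                    # made-up user turn
bot_msg = "It shows a **volcano** near the coast."  # made-up model turn

processed_pair = (
    '<pre style="font-family: var(--font)">' + user_msg + "</pre>",
    str(md.convert(bot_msg)),
)
print(processed_pair[1])
```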
model/llava/serve/gradio_web_server.py CHANGED
@@ -1,22 +1,20 @@
1
  import argparse
2
- from collections import defaultdict
3
  import datetime
 
4
  import json
5
  import os
6
  import time
 
7
 
8
  import gradio as gr
9
  import requests
10
-
11
- from llava.conversation import (default_conversation, conv_templates,
12
- SeparatorStyle)
13
  from llava.constants import LOGDIR
14
- from llava.utils import (build_logger, server_error_msg,
15
- violates_moderation, moderation_msg)
16
- from llava.serve.gradio_patch import Chatbot as grChatbot
17
  from llava.serve.gradio_css import code_highlight_css
18
- import hashlib
19
-
 
20
 
21
  logger = build_logger("gradio_web_server", "gradio_web_server.log")
22
 
@@ -65,31 +63,33 @@ def load_demo(url_params, request: gr.Request):
65
  if "model" in url_params:
66
  model = url_params["model"]
67
  if model in models:
68
- dropdown_update = gr.Dropdown.update(
69
- value=model, visible=True)
70
 
71
  state = default_conversation.copy()
72
- return (state,
73
- dropdown_update,
74
- gr.Chatbot.update(visible=True),
75
- gr.Textbox.update(visible=True),
76
- gr.Button.update(visible=True),
77
- gr.Row.update(visible=True),
78
- gr.Accordion.update(visible=True))
 
 
79
 
80
 
81
  def load_demo_refresh_model_list(request: gr.Request):
82
  logger.info(f"load_demo. ip: {request.client.host}")
83
  models = get_model_list()
84
  state = default_conversation.copy()
85
- return (state, gr.Dropdown.update(
86
- choices=models,
87
- value=models[0] if len(models) > 0 else ""),
88
- gr.Chatbot.update(visible=True),
89
- gr.Textbox.update(visible=True),
90
- gr.Button.update(visible=True),
91
- gr.Row.update(visible=True),
92
- gr.Accordion.update(visible=True))
 
93
 
94
 
95
  def vote_last_response(state, vote_type, model_selector, request: gr.Request):
@@ -148,13 +148,14 @@ def add_text(state, text, image, image_process_mode, request: gr.Request):
148
  if flagged:
149
  state.skip_next = True
150
  return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
151
- no_change_btn,) * 5
 
152
 
153
  text = text[:1536] # Hard cut-off
154
  if image is not None:
155
  text = text[:1200] # Hard cut-off for images
156
- if '<image>' not in text:
157
- text = text + '\n<image>'
158
  text = (text, image, image_process_mode)
159
  state = default_conversation.copy()
160
  state.append_message(state.roles[0], text)
@@ -195,9 +196,9 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
195
  template_name = "multimodal"
196
  elif "mpt" in model_name:
197
  template_name = "mpt_text"
198
- elif "koala" in model_name: # Hardcode the condition
199
  template_name = "bair_v1"
200
- elif "v1" in model_name: # vicuna v1_1/v1_2
201
  template_name = "vicuna_v1_1"
202
  else:
203
  template_name = "v1"
@@ -208,15 +209,24 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
208
 
209
  # Query worker address
210
  controller_url = args.controller_url
211
- ret = requests.post(controller_url + "/get_worker_address",
212
- json={"model": model_name})
 
213
  worker_addr = ret.json()["address"]
214
  logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
215
 
216
  # No available worker
217
  if worker_addr == "":
218
  state.messages[-1][-1] = server_error_msg
219
- yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
 
 
 
 
 
 
 
 
220
  return
221
 
222
  # Construct prompt
@@ -226,7 +236,9 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
226
  all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
227
  for image, hash in zip(all_images, all_image_hash):
228
  t = datetime.datetime.now()
229
- filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
 
 
230
  if not os.path.isfile(filename):
231
  os.makedirs(os.path.dirname(filename), exist_ok=True)
232
  image.save(filename)
@@ -237,37 +249,56 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
237
  "prompt": prompt,
238
  "temperature": float(temperature),
239
  "max_new_tokens": min(int(max_new_tokens), 1536),
240
- "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
241
- "images": f'List of {len(state.get_images())} images: {all_image_hash}',
 
 
242
  }
243
  logger.info(f"==== request ====\n{pload}")
244
 
245
- pload['images'] = state.get_images()
246
 
247
  state.messages[-1][-1] = "▌"
248
  yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
249
 
250
  try:
251
  # Stream output
252
- response = requests.post(worker_addr + "/worker_generate_stream",
253
- headers=headers, json=pload, stream=True, timeout=10)
 
 
 
 
 
254
  for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
255
  if chunk:
256
  data = json.loads(chunk.decode())
257
  if data["error_code"] == 0:
258
- output = data["text"][len(prompt):].strip()
259
  output = post_process_code(output)
260
  state.messages[-1][-1] = output + "▌"
261
  yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
262
  else:
263
  output = data["text"] + f" (error_code: {data['error_code']})"
264
  state.messages[-1][-1] = output
265
- yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
 
 
 
 
 
 
266
  return
267
  time.sleep(0.03)
268
  except requests.exceptions.RequestException as e:
269
  state.messages[-1][-1] = server_error_msg
270
- yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
 
 
 
 
 
 
271
  return
272
 
273
  state.messages[-1][-1] = state.messages[-1][-1][:-1]
@@ -289,27 +320,30 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
289
  }
290
  fout.write(json.dumps(data) + "\n")
291
 
292
- title_markdown = ("""
 
293
  # 🌋 LLaVA: Large Language and Vision Assistant
294
  [[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485) [[Code]](https://github.com/haotian-liu/LLaVA) [[Model]](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v0)
295
- """)
296
 
297
- tos_markdown = ("""
298
  ### Terms of use
299
  By using this service, users are required to agree to the following terms:
300
  The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
301
  Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
302
  For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
303
- """)
304
 
305
 
306
- learn_more_markdown = ("""
307
  ### License
308
  The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
309
- """)
310
 
311
 
312
- css = code_highlight_css + """
 
 
313
  pre {
314
  white-space: pre-wrap; /* Since CSS 2.1 */
315
  white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
@@ -318,11 +352,13 @@ pre {
318
  word-wrap: break-word; /* Internet Explorer 5.5+ */
319
  }
320
  """
 
321
 
322
 
323
  def build_demo(embed_mode):
324
- textbox = gr.Textbox(show_label=False,
325
- placeholder="Enter text and press ENTER", visible=False).style(container=False)
 
326
  with gr.Blocks(title="LLaVA", theme=gr.themes.Base(), css=css) as demo:
327
  state = gr.State()
328
 
@@ -336,26 +372,55 @@ def build_demo(embed_mode):
336
  choices=models,
337
  value=models[0] if len(models) > 0 else "",
338
  interactive=True,
339
- show_label=False).style(container=False)
 
340
 
341
  imagebox = gr.Image(type="pil")
342
  image_process_mode = gr.Radio(
343
  ["Crop", "Resize", "Pad"],
344
  value="Crop",
345
- label="Preprocess for non-square image")
 
346
 
347
  cur_dir = os.path.dirname(os.path.abspath(__file__))
348
- gr.Examples(examples=[
349
- [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
350
- [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
351
- ], inputs=[imagebox, textbox])
352
-
353
- with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
354
- temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
355
- max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
 
 
 
 
 
 
 
 
356
 
357
  with gr.Column(scale=6):
358
- chatbot = grChatbot(elem_id="chatbot", label="LLaVA Chatbot", visible=False).style(height=550)
 
 
359
  with gr.Row():
360
  with gr.Column(scale=8):
361
  textbox.render()
@@ -365,7 +430,7 @@ def build_demo(embed_mode):
365
  upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
366
  downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
367
  flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
368
- #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
369
  regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
370
  clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
371
 
@@ -376,32 +441,82 @@ def build_demo(embed_mode):
376
 
377
  # Register listeners
378
  btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
379
- upvote_btn.click(upvote_last_response,
380
- [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
381
- downvote_btn.click(downvote_last_response,
382
- [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
383
- flag_btn.click(flag_last_response,
384
- [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
385
- regenerate_btn.click(regenerate, [state, image_process_mode],
386
- [state, chatbot, textbox, imagebox] + btn_list).then(
387
- http_bot, [state, model_selector, temperature, max_output_tokens],
388
- [state, chatbot] + btn_list)
389
- clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)
390
-
391
- textbox.submit(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list
392
- ).then(http_bot, [state, model_selector, temperature, max_output_tokens],
393
- [state, chatbot] + btn_list)
394
- submit_btn.click(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list
395
- ).then(http_bot, [state, model_selector, temperature, max_output_tokens],
396
- [state, chatbot] + btn_list)
 
 
 
 
 
 
397
 
398
  if args.model_list_mode == "once":
399
- demo.load(load_demo, [url_params], [state, model_selector,
400
- chatbot, textbox, submit_btn, button_row, parameter_row],
401
- _js=get_window_url_params)
 
 
 
 
 
 
 
 
 
 
 
402
  elif args.model_list_mode == "reload":
403
- demo.load(load_demo_refresh_model_list, None, [state, model_selector,
404
- chatbot, textbox, submit_btn, button_row, parameter_row])
 
 
 
 
 
 
 
 
 
 
 
405
  else:
406
  raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
407
 
@@ -414,8 +529,9 @@ if __name__ == "__main__":
414
  parser.add_argument("--port", type=int)
415
  parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
416
  parser.add_argument("--concurrency-count", type=int, default=8)
417
- parser.add_argument("--model-list-mode", type=str, default="once",
418
- choices=["once", "reload"])
 
419
  parser.add_argument("--share", action="store_true")
420
  parser.add_argument("--moderate", action="store_true")
421
  parser.add_argument("--embed", action="store_true")
@@ -426,6 +542,6 @@ if __name__ == "__main__":
426
 
427
  logger.info(args)
428
  demo = build_demo(args.embed)
429
- demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10,
430
- api_open=False).launch(
431
- server_name=args.host, server_port=args.port, share=args.share)
 
1
  import argparse
 
2
  import datetime
3
+ import hashlib
4
  import json
5
  import os
6
  import time
7
+ from collections import defaultdict
8
 
9
  import gradio as gr
10
  import requests
 
 
 
11
  from llava.constants import LOGDIR
12
+ from llava.conversation import (SeparatorStyle, conv_templates,
13
+ default_conversation)
 
14
  from llava.serve.gradio_css import code_highlight_css
15
+ from llava.serve.gradio_patch import Chatbot as grChatbot
16
+ from llava.utils import (build_logger, moderation_msg, server_error_msg,
17
+ violates_moderation)
18
 
19
  logger = build_logger("gradio_web_server", "gradio_web_server.log")
20
 
 
63
  if "model" in url_params:
64
  model = url_params["model"]
65
  if model in models:
66
+ dropdown_update = gr.Dropdown.update(value=model, visible=True)
 
67
 
68
  state = default_conversation.copy()
69
+ return (
70
+ state,
71
+ dropdown_update,
72
+ gr.Chatbot.update(visible=True),
73
+ gr.Textbox.update(visible=True),
74
+ gr.Button.update(visible=True),
75
+ gr.Row.update(visible=True),
76
+ gr.Accordion.update(visible=True),
77
+ )
78
 
79
 
80
  def load_demo_refresh_model_list(request: gr.Request):
81
  logger.info(f"load_demo. ip: {request.client.host}")
82
  models = get_model_list()
83
  state = default_conversation.copy()
84
+ return (
85
+ state,
86
+ gr.Dropdown.update(choices=models, value=models[0] if len(models) > 0 else ""),
87
+ gr.Chatbot.update(visible=True),
88
+ gr.Textbox.update(visible=True),
89
+ gr.Button.update(visible=True),
90
+ gr.Row.update(visible=True),
91
+ gr.Accordion.update(visible=True),
92
+ )
93
 
94
 
95
  def vote_last_response(state, vote_type, model_selector, request: gr.Request):
 
148
  if flagged:
149
  state.skip_next = True
150
  return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
151
+ no_change_btn,
152
+ ) * 5
153
 
154
  text = text[:1536] # Hard cut-off
155
  if image is not None:
156
  text = text[:1200] # Hard cut-off for images
157
+ if "<image>" not in text:
158
+ text = text + "\n<image>"
159
  text = (text, image, image_process_mode)
160
  state = default_conversation.copy()
161
  state.append_message(state.roles[0], text)
 
196
  template_name = "multimodal"
197
  elif "mpt" in model_name:
198
  template_name = "mpt_text"
199
+ elif "koala" in model_name: # Hardcode the condition
200
  template_name = "bair_v1"
201
+ elif "v1" in model_name: # vicuna v1_1/v1_2
202
  template_name = "vicuna_v1_1"
203
  else:
204
  template_name = "v1"
 
209
 
210
  # Query worker address
211
  controller_url = args.controller_url
212
+ ret = requests.post(
213
+ controller_url + "/get_worker_address", json={"model": model_name}
214
+ )
215
  worker_addr = ret.json()["address"]
216
  logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
217
 
218
  # No available worker
219
  if worker_addr == "":
220
  state.messages[-1][-1] = server_error_msg
221
+ yield (
222
+ state,
223
+ state.to_gradio_chatbot(),
224
+ disable_btn,
225
+ disable_btn,
226
+ disable_btn,
227
+ enable_btn,
228
+ enable_btn,
229
+ )
230
  return
231
 
232
  # Construct prompt
 
236
  all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
237
  for image, hash in zip(all_images, all_image_hash):
238
  t = datetime.datetime.now()
239
+ filename = os.path.join(
240
+ LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg"
241
+ )
242
  if not os.path.isfile(filename):
243
  os.makedirs(os.path.dirname(filename), exist_ok=True)
244
  image.save(filename)
 
249
  "prompt": prompt,
250
  "temperature": float(temperature),
251
  "max_new_tokens": min(int(max_new_tokens), 1536),
252
+ "stop": state.sep
253
+ if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
254
+ else state.sep2,
255
+ "images": f"List of {len(state.get_images())} images: {all_image_hash}",
256
  }
257
  logger.info(f"==== request ====\n{pload}")
258
 
259
+ pload["images"] = state.get_images()
260
 
261
  state.messages[-1][-1] = "▌"
262
  yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
263
 
264
  try:
265
  # Stream output
266
+ response = requests.post(
267
+ worker_addr + "/worker_generate_stream",
268
+ headers=headers,
269
+ json=pload,
270
+ stream=True,
271
+ timeout=10,
272
+ )
273
  for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
274
  if chunk:
275
  data = json.loads(chunk.decode())
276
  if data["error_code"] == 0:
277
+ output = data["text"][len(prompt) :].strip()
278
  output = post_process_code(output)
279
  state.messages[-1][-1] = output + "▌"
280
  yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
281
  else:
282
  output = data["text"] + f" (error_code: {data['error_code']})"
283
  state.messages[-1][-1] = output
284
+ yield (state, state.to_gradio_chatbot()) + (
285
+ disable_btn,
286
+ disable_btn,
287
+ disable_btn,
288
+ enable_btn,
289
+ enable_btn,
290
+ )
291
  return
292
  time.sleep(0.03)
293
  except requests.exceptions.RequestException as e:
294
  state.messages[-1][-1] = server_error_msg
295
+ yield (state, state.to_gradio_chatbot()) + (
296
+ disable_btn,
297
+ disable_btn,
298
+ disable_btn,
299
+ enable_btn,
300
+ enable_btn,
301
+ )
302
  return
303
 
304
  state.messages[-1][-1] = state.messages[-1][-1][:-1]
 
320
  }
321
  fout.write(json.dumps(data) + "\n")
322
 
323
+
324
+ title_markdown = """
325
  # 🌋 LLaVA: Large Language and Vision Assistant
326
  [[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485) [[Code]](https://github.com/haotian-liu/LLaVA) [[Model]](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v0)
327
+ """
328
 
329
+ tos_markdown = """
330
  ### Terms of use
331
  By using this service, users are required to agree to the following terms:
332
  The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
333
  Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
334
  For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
335
+ """
336
 
337
 
338
+ learn_more_markdown = """
339
  ### License
340
  The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
341
+ """
342
 
343
 
344
+ css = (
345
+ code_highlight_css
346
+ + """
347
  pre {
348
  white-space: pre-wrap; /* Since CSS 2.1 */
349
  white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
 
352
  word-wrap: break-word; /* Internet Explorer 5.5+ */
353
  }
354
  """
355
+ )
356
 
357
 
358
  def build_demo(embed_mode):
359
+ textbox = gr.Textbox(
360
+ show_label=False, placeholder="Enter text and press ENTER", visible=False
361
+ ).style(container=False)
362
  with gr.Blocks(title="LLaVA", theme=gr.themes.Base(), css=css) as demo:
363
  state = gr.State()
364
 
 
372
  choices=models,
373
  value=models[0] if len(models) > 0 else "",
374
  interactive=True,
375
+ show_label=False,
376
+ ).style(container=False)
377
 
378
  imagebox = gr.Image(type="pil")
379
  image_process_mode = gr.Radio(
380
  ["Crop", "Resize", "Pad"],
381
  value="Crop",
382
+ label="Preprocess for non-square image",
383
+ )
384
 
385
  cur_dir = os.path.dirname(os.path.abspath(__file__))
386
+ gr.Examples(
387
+ examples=[
388
+ [
389
+ f"{cur_dir}/examples/extreme_ironing.jpg",
390
+ "What is unusual about this image?",
391
+ ],
392
+ [
393
+ f"{cur_dir}/examples/waterview.jpg",
394
+ "What are the things I should be cautious about when I visit here?",
395
+ ],
396
+ ],
397
+ inputs=[imagebox, textbox],
398
+ )
399
+
400
+ with gr.Accordion(
401
+ "Parameters", open=False, visible=False
402
+ ) as parameter_row:
403
+ temperature = gr.Slider(
404
+ minimum=0.0,
405
+ maximum=1.0,
406
+ value=0.2,
407
+ step=0.1,
408
+ interactive=True,
409
+ label="Temperature",
410
+ )
411
+ max_output_tokens = gr.Slider(
412
+ minimum=0,
413
+ maximum=1024,
414
+ value=512,
415
+ step=64,
416
+ interactive=True,
417
+ label="Max output tokens",
418
+ )
419
 
420
  with gr.Column(scale=6):
421
+ chatbot = grChatbot(
422
+ elem_id="chatbot", label="LLaVA Chatbot", visible=False
423
+ ).style(height=550)
424
  with gr.Row():
425
  with gr.Column(scale=8):
426
  textbox.render()
 
430
  upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
431
  downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
432
  flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
433
+ # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
434
  regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
435
  clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
436
 
 
441
 
442
  # Register listeners
443
  btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
444
+ upvote_btn.click(
445
+ upvote_last_response,
446
+ [state, model_selector],
447
+ [textbox, upvote_btn, downvote_btn, flag_btn],
448
+ )
449
+ downvote_btn.click(
450
+ downvote_last_response,
451
+ [state, model_selector],
452
+ [textbox, upvote_btn, downvote_btn, flag_btn],
453
+ )
454
+ flag_btn.click(
455
+ flag_last_response,
456
+ [state, model_selector],
457
+ [textbox, upvote_btn, downvote_btn, flag_btn],
458
+ )
459
+ regenerate_btn.click(
460
+ regenerate,
461
+ [state, image_process_mode],
462
+ [state, chatbot, textbox, imagebox] + btn_list,
463
+ ).then(
464
+ http_bot,
465
+ [state, model_selector, temperature, max_output_tokens],
466
+ [state, chatbot] + btn_list,
467
+ )
468
+ clear_btn.click(
469
+ clear_history, None, [state, chatbot, textbox, imagebox] + btn_list
470
+ )
471
+
472
+ textbox.submit(
473
+ add_text,
474
+ [state, textbox, imagebox, image_process_mode],
475
+ [state, chatbot, textbox, imagebox] + btn_list,
476
+ ).then(
477
+ http_bot,
478
+ [state, model_selector, temperature, max_output_tokens],
479
+ [state, chatbot] + btn_list,
480
+ )
481
+ submit_btn.click(
482
+ add_text,
483
+ [state, textbox, imagebox, image_process_mode],
484
+ [state, chatbot, textbox, imagebox] + btn_list,
485
+ ).then(
486
+ http_bot,
487
+ [state, model_selector, temperature, max_output_tokens],
488
+ [state, chatbot] + btn_list,
489
+ )
490
 
491
  if args.model_list_mode == "once":
492
+ demo.load(
493
+ load_demo,
494
+ [url_params],
495
+ [
496
+ state,
497
+ model_selector,
498
+ chatbot,
499
+ textbox,
500
+ submit_btn,
501
+ button_row,
502
+ parameter_row,
503
+ ],
504
+ _js=get_window_url_params,
505
+ )
506
  elif args.model_list_mode == "reload":
507
+ demo.load(
508
+ load_demo_refresh_model_list,
509
+ None,
510
+ [
511
+ state,
512
+ model_selector,
513
+ chatbot,
514
+ textbox,
515
+ submit_btn,
516
+ button_row,
517
+ parameter_row,
518
+ ],
519
+ )
520
  else:
521
  raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
522
 
 
529
  parser.add_argument("--port", type=int)
530
  parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
531
  parser.add_argument("--concurrency-count", type=int, default=8)
532
+ parser.add_argument(
533
+ "--model-list-mode", type=str, default="once", choices=["once", "reload"]
534
+ )
535
  parser.add_argument("--share", action="store_true")
536
  parser.add_argument("--moderate", action="store_true")
537
  parser.add_argument("--embed", action="store_true")
 
542
 
543
  logger.info(args)
544
  demo = build_demo(args.embed)
545
+ demo.queue(
546
+ concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
547
+ ).launch(server_name=args.host, server_port=args.port, share=args.share)
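The `http_bot` hunks above are formatting-only, but they spell out the request/streaming contract between the web server and a worker: the payload carries the prompt, sampling parameters, a stop string picked from the conversation's separator style, and the images, and the worker streams back NUL-delimited JSON chunks containing the full text generated so far. A stripped-down sketch of that client loop, with the model name, headers, and stop string treated as placeholders:

```python
# Sketch of the streaming pattern http_bot uses; worker_addr and prompt are
# placeholders for values resolved at runtime via the controller.
import json

import requests


def stream_generation(worker_addr: str, prompt: str, temperature: float = 0.2):
    pload = {
        "model": "llava-13b",  # hypothetical model name
        "prompt": prompt,
        "temperature": temperature,
        "max_new_tokens": 512,
        "stop": "###",  # the real stop string depends on the conversation template
    }
    response = requests.post(
        worker_addr + "/worker_generate_stream",
        headers={"User-Agent": "LLaVA Client"},  # placeholder headers
        json=pload,
        stream=True,
        timeout=10,
    )
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        if data["error_code"] != 0:
            raise RuntimeError(f"worker error: {data['error_code']}")
        # the worker echoes the prompt, so strip it before displaying
        yield data["text"][len(prompt):].strip()
```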
model/llava/serve/model_worker.py CHANGED
@@ -4,25 +4,23 @@ A model worker executes the model.
4
  import argparse
5
  import asyncio
6
  import dataclasses
7
- import logging
8
  import json
9
- import time
10
- from typing import List, Union
11
  import threading
 
12
  import uuid
 
 
13
 
14
- from fastapi import FastAPI, Request, BackgroundTasks
15
- from fastapi.responses import StreamingResponse
16
  import requests
17
- from transformers import AutoTokenizer, AutoModelForCausalLM
18
  import torch
19
  import uvicorn
20
- from functools import partial
21
-
22
  from llava.constants import WORKER_HEART_BEAT_INTERVAL
23
- from llava.utils import (build_logger, server_error_msg,
24
- pretty_print_semaphore)
25
  from llava.model import *
 
 
26
 
27
  GB = 1 << 30
28
 
@@ -40,7 +38,6 @@ DEFAULT_IM_END_TOKEN = "<im_end>"
40
 
41
 
42
  def heart_beat_worker(controller):
43
-
44
  while True:
45
  time.sleep(WORKER_HEART_BEAT_INTERVAL)
46
  controller.send_heart_beat()
@@ -56,38 +53,66 @@ def load_model(model_path, model_name, num_gpus):
56
  }
57
 
58
  tokenizer = AutoTokenizer.from_pretrained(model_path)
59
- if 'llava' in model_name.lower():
60
- if 'mpt' in model_name.lower():
61
- model = LlavaMPTForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs)
 
 
62
  else:
63
- model = LlavaLlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs)
64
- elif 'mpt' in model_name.lower():
65
- model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
 
 
 
 
 
 
 
 
66
  else:
67
- model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs)
 
 
68
 
69
  image_processor = None
70
 
71
- if 'llava' in model_name.lower():
72
  from transformers import CLIPImageProcessor, CLIPVisionModel
73
- image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)
 
 
 
74
 
75
  mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
76
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
77
  if mm_use_im_start_end:
78
- tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
 
 
79
 
80
  vision_tower = model.get_model().vision_tower[0]
81
- if vision_tower.device.type == 'meta':
82
- vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda()
 
 
 
 
83
  model.get_model().vision_tower[0] = vision_tower
84
  else:
85
- vision_tower.to(device='cuda', dtype=torch.float16)
86
  vision_config = vision_tower.config
87
- vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
 
 
88
  vision_config.use_im_start_end = mm_use_im_start_end
89
  if mm_use_im_start_end:
90
- vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
 
 
 
 
 
91
 
92
  if num_gpus == 1:
93
  model.cuda()
@@ -101,11 +126,17 @@ def load_model(model_path, model_name, num_gpus):
101
 
102
 
103
  class ModelWorker:
104
- def __init__(self, controller_addr, worker_addr,
105
- worker_id, no_register,
106
- model_path, model_name,
107
- keep_aspect_ratio,
108
- num_gpus):
 
 
 
 
 
 
109
  self.controller_addr = controller_addr
110
  self.worker_addr = worker_addr
111
  self.worker_id = worker_id
@@ -113,7 +144,7 @@ class ModelWorker:
113
  model_path = model_path[:-1]
114
  if model_name is None:
115
  model_paths = model_path.split("/")
116
- if model_paths[-1].startswith('checkpoint-'):
117
  self.model_name = model_paths[-2] + "_" + model_paths[-1]
118
  else:
119
  self.model_name = model_paths[-1]
@@ -123,13 +154,15 @@ class ModelWorker:
123
  logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
124
  self.keep_aspect_ratio = keep_aspect_ratio
125
  self.tokenizer, self.model, self.image_processor, self.context_len = load_model(
126
- model_path, self.model_name, num_gpus)
127
- self.is_multimodal = 'llava' in model_path.lower()
 
128
 
129
  if not no_register:
130
  self.register_to_controller()
131
  self.heart_beat_thread = threading.Thread(
132
- target=heart_beat_worker, args=(self,))
 
133
  self.heart_beat_thread.start()
134
 
135
  def register_to_controller(self):
@@ -139,23 +172,30 @@ class ModelWorker:
139
  data = {
140
  "worker_name": self.worker_addr,
141
  "check_heart_beat": True,
142
- "worker_status": self.get_status()
143
  }
144
  r = requests.post(url, json=data)
145
  assert r.status_code == 200
146
 
147
  def send_heart_beat(self):
148
- logger.info(f"Send heart beat. Models: {[self.model_name]}. "
149
- f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
150
- f"global_counter: {global_counter}")
 
 
151
 
152
  url = self.controller_addr + "/receive_heart_beat"
153
 
154
  while True:
155
  try:
156
- ret = requests.post(url, json={
157
- "worker_name": self.worker_addr,
158
- "queue_length": self.get_queue_length()}, timeout=5)
 
 
 
 
 
159
  exist = ret.json()["exist"]
160
  break
161
  except requests.exceptions.RequestException as e:
@@ -169,8 +209,15 @@ class ModelWorker:
169
  if model_semaphore is None:
170
  return 0
171
  else:
172
- return args.limit_model_concurrency - model_semaphore._value + (len(
173
- model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
 
 
 
 
 
 
 
174
 
175
  def get_status(self):
176
  return {
@@ -181,20 +228,30 @@ class ModelWorker:
181
 
182
  @torch.inference_mode()
183
  def generate_stream(self, params):
184
- tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
 
 
 
 
185
 
186
  prompt = params["prompt"]
187
  ori_prompt = prompt
188
  images = params.get("images", None)
189
  if images is not None and len(images) > 0 and self.is_multimodal:
190
- from PIL import Image
191
- from io import BytesIO
192
  import base64
 
 
 
 
193
  assert type(images) is list
194
  if len(images) > 0:
195
  # assert len(images) == 1, "Only support one image for now"
196
- images = [Image.open(BytesIO(base64.b64decode(image))) for image in images]
197
- assert len(images) == prompt.count(DEFAULT_IMAGE_TOKEN), "Number of images does not match number of <image> tokens in prompt"
 
 
 
 
198
 
199
  if self.keep_aspect_ratio:
200
  new_images = []
@@ -203,21 +260,40 @@ class ModelWorker:
203
  aspect_ratio = max_hw / min_hw
204
  max_len, min_len = 448, 224
205
  shortest_edge = int(min(max_len / aspect_ratio, min_len))
206
- image = image_processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge})['pixel_values'][0]
207
- new_images.append(image.to(self.model.device, dtype=torch.float16))
 
 
 
 
 
 
 
208
  # replace the image token with the image patch token in the prompt (each occurrence)
209
- cur_token_len = (image.shape[1]//14) * (image.shape[2]//14)
210
  replace_token = DEFAULT_IMAGE_PATCH_TOKEN * cur_token_len
211
- if getattr(self.model.config, 'mm_use_im_start_end', False):
212
- replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
 
 
 
 
213
  prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token, 1)
214
  images = new_images
215
  else:
216
- images = image_processor(images, return_tensors='pt')['pixel_values']
 
 
217
  images = images.to(self.model.device, dtype=torch.float16)
218
- replace_token = DEFAULT_IMAGE_PATCH_TOKEN * 256 # HACK: 256 is the max image token length hacked
219
- if getattr(self.model.config, 'mm_use_im_start_end', False):
220
- replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
 
 
 
 
 
 
221
  prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
222
  else:
223
  images = None
@@ -249,18 +325,20 @@ class ModelWorker:
249
  for i in range(max_new_tokens):
250
  if i == 0:
251
  out = model(
252
- torch.as_tensor([input_ids]).cuda(),
253
- use_cache=True,
254
- **image_args)
255
  logits = out.logits
256
  past_key_values = out.past_key_values
257
  else:
258
  attention_mask = torch.ones(
259
- 1, past_key_values[0][0].shape[-2] + 1, device="cuda")
260
- out = model(input_ids=torch.as_tensor([[token]], device="cuda"),
261
- use_cache=True,
262
- attention_mask=attention_mask,
263
- past_key_values=past_key_values)
 
 
 
264
  logits = out.logits
265
  past_key_values = out.past_key_values
266
 
@@ -342,7 +420,9 @@ async def generate_stream(request: Request):
342
  worker.send_heart_beat()
343
  generator = worker.generate_stream_gate(params)
344
  background_tasks = BackgroundTasks()
345
- background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
 
 
346
  return StreamingResponse(generator, background=background_tasks)
347
 
348
 
@@ -355,13 +435,17 @@ if __name__ == "__main__":
355
  parser = argparse.ArgumentParser()
356
  parser.add_argument("--host", type=str, default="localhost")
357
  parser.add_argument("--port", type=int, default=21002)
358
- parser.add_argument("--worker-address", type=str,
359
- default="http://localhost:21002")
360
- parser.add_argument("--controller-address", type=str,
361
- default="http://localhost:21001")
362
  parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
363
  parser.add_argument("--model-name", type=str)
364
- parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
 
 
 
 
365
  parser.add_argument("--keep-aspect-ratio", action="store_true")
366
  parser.add_argument("--num-gpus", type=int, default=1)
367
  parser.add_argument("--limit-model-concurrency", type=int, default=5)
@@ -371,14 +455,18 @@ if __name__ == "__main__":
371
  logger.info(f"args: {args}")
372
 
373
  if args.multi_modal:
374
- logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
375
-
376
- worker = ModelWorker(args.controller_address,
377
- args.worker_address,
378
- worker_id,
379
- args.no_register,
380
- args.model_path,
381
- args.model_name,
382
- args.keep_aspect_ratio,
383
- args.num_gpus)
 
 
 
 
384
  uvicorn.run(app, host=args.host, port=args.port, log_level="info")
 
4
  import argparse
5
  import asyncio
6
  import dataclasses
 
7
  import json
8
+ import logging
 
9
  import threading
10
+ import time
11
  import uuid
12
+ from functools import partial
13
+ from typing import List, Union
14
 
 
 
15
  import requests
 
16
  import torch
17
  import uvicorn
18
+ from fastapi import BackgroundTasks, FastAPI, Request
19
+ from fastapi.responses import StreamingResponse
20
  from llava.constants import WORKER_HEART_BEAT_INTERVAL
 
 
21
  from llava.model import *
22
+ from llava.utils import build_logger, pretty_print_semaphore, server_error_msg
23
+ from transformers import AutoModelForCausalLM, AutoTokenizer
24
 
25
  GB = 1 << 30
26
 
 
38
 
39
 
40
  def heart_beat_worker(controller):
 
41
  while True:
42
  time.sleep(WORKER_HEART_BEAT_INTERVAL)
43
  controller.send_heart_beat()
 
53
  }
54
 
55
  tokenizer = AutoTokenizer.from_pretrained(model_path)
56
+ if "llava" in model_name.lower():
57
+ if "mpt" in model_name.lower():
58
+ model = LlavaMPTForCausalLM.from_pretrained(
59
+ model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs
60
+ )
61
  else:
62
+ model = LlavaLlamaForCausalLM.from_pretrained(
63
+ model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs
64
+ )
65
+ elif "mpt" in model_name.lower():
66
+ model = AutoModelForCausalLM.from_pretrained(
67
+ model_path,
68
+ torch_dtype=torch.float16,
69
+ low_cpu_mem_usage=True,
70
+ trust_remote_code=True,
71
+ **kwargs,
72
+ )
73
  else:
74
+ model = AutoModelForCausalLM.from_pretrained(
75
+ model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs
76
+ )
77
 
78
  image_processor = None
79
 
80
+ if "llava" in model_name.lower():
81
  from transformers import CLIPImageProcessor, CLIPVisionModel
82
+
83
+ image_processor = CLIPImageProcessor.from_pretrained(
84
+ model.config.mm_vision_tower, torch_dtype=torch.float16
85
+ )
86
 
87
  mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
88
  tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
89
  if mm_use_im_start_end:
90
+ tokenizer.add_tokens(
91
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
92
+ )
93
 
94
  vision_tower = model.get_model().vision_tower[0]
95
+ if vision_tower.device.type == "meta":
96
+ vision_tower = CLIPVisionModel.from_pretrained(
97
+ vision_tower.config._name_or_path,
98
+ torch_dtype=torch.float16,
99
+ low_cpu_mem_usage=True,
100
+ ).cuda()
101
  model.get_model().vision_tower[0] = vision_tower
102
  else:
103
+ vision_tower.to(device="cuda", dtype=torch.float16)
104
  vision_config = vision_tower.config
105
+ vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
106
+ [DEFAULT_IMAGE_PATCH_TOKEN]
107
+ )[0]
108
  vision_config.use_im_start_end = mm_use_im_start_end
109
  if mm_use_im_start_end:
110
+ (
111
+ vision_config.im_start_token,
112
+ vision_config.im_end_token,
113
+ ) = tokenizer.convert_tokens_to_ids(
114
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
115
+ )
116
 
117
  if num_gpus == 1:
118
  model.cuda()
 
126
 
127
 
128
  class ModelWorker:
129
+ def __init__(
130
+ self,
131
+ controller_addr,
132
+ worker_addr,
133
+ worker_id,
134
+ no_register,
135
+ model_path,
136
+ model_name,
137
+ keep_aspect_ratio,
138
+ num_gpus,
139
+ ):
140
  self.controller_addr = controller_addr
141
  self.worker_addr = worker_addr
142
  self.worker_id = worker_id
 
144
  model_path = model_path[:-1]
145
  if model_name is None:
146
  model_paths = model_path.split("/")
147
+ if model_paths[-1].startswith("checkpoint-"):
148
  self.model_name = model_paths[-2] + "_" + model_paths[-1]
149
  else:
150
  self.model_name = model_paths[-1]
 
154
  logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
155
  self.keep_aspect_ratio = keep_aspect_ratio
156
  self.tokenizer, self.model, self.image_processor, self.context_len = load_model(
157
+ model_path, self.model_name, num_gpus
158
+ )
159
+ self.is_multimodal = "llava" in model_path.lower()
160
 
161
  if not no_register:
162
  self.register_to_controller()
163
  self.heart_beat_thread = threading.Thread(
164
+ target=heart_beat_worker, args=(self,)
165
+ )
166
  self.heart_beat_thread.start()
167
 
168
  def register_to_controller(self):
 
172
  data = {
173
  "worker_name": self.worker_addr,
174
  "check_heart_beat": True,
175
+ "worker_status": self.get_status(),
176
  }
177
  r = requests.post(url, json=data)
178
  assert r.status_code == 200
179
 
180
  def send_heart_beat(self):
181
+ logger.info(
182
+ f"Send heart beat. Models: {[self.model_name]}. "
183
+ f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
184
+ f"global_counter: {global_counter}"
185
+ )
186
 
187
  url = self.controller_addr + "/receive_heart_beat"
188
 
189
  while True:
190
  try:
191
+ ret = requests.post(
192
+ url,
193
+ json={
194
+ "worker_name": self.worker_addr,
195
+ "queue_length": self.get_queue_length(),
196
+ },
197
+ timeout=5,
198
+ )
199
  exist = ret.json()["exist"]
200
  break
201
  except requests.exceptions.RequestException as e:
 
209
  if model_semaphore is None:
210
  return 0
211
  else:
212
+ return (
213
+ args.limit_model_concurrency
214
+ - model_semaphore._value
215
+ + (
216
+ len(model_semaphore._waiters)
217
+ if model_semaphore._waiters is not None
218
+ else 0
219
+ )
220
+ )
221
 
222
  def get_status(self):
223
  return {
 
228
 
229
  @torch.inference_mode()
230
  def generate_stream(self, params):
231
+ tokenizer, model, image_processor = (
232
+ self.tokenizer,
233
+ self.model,
234
+ self.image_processor,
235
+ )
236
 
237
  prompt = params["prompt"]
238
  ori_prompt = prompt
239
  images = params.get("images", None)
240
  if images is not None and len(images) > 0 and self.is_multimodal:
 
 
241
  import base64
242
+ from io import BytesIO
243
+
244
+ from PIL import Image
245
+
246
  assert type(images) is list
247
  if len(images) > 0:
248
  # assert len(images) == 1, "Only support one image for now"
249
+ images = [
250
+ Image.open(BytesIO(base64.b64decode(image))) for image in images
251
+ ]
252
+ assert len(images) == prompt.count(
253
+ DEFAULT_IMAGE_TOKEN
254
+ ), "Number of images does not match number of <image> tokens in prompt"
255
 
256
  if self.keep_aspect_ratio:
257
  new_images = []
 
260
  aspect_ratio = max_hw / min_hw
261
  max_len, min_len = 448, 224
262
  shortest_edge = int(min(max_len / aspect_ratio, min_len))
263
+ image = image_processor.preprocess(
264
+ image,
265
+ return_tensors="pt",
266
+ do_center_crop=False,
267
+ size={"shortest_edge": shortest_edge},
268
+ )["pixel_values"][0]
269
+ new_images.append(
270
+ image.to(self.model.device, dtype=torch.float16)
271
+ )
272
  # replace the image token with the image patch token in the prompt (each occurrence)
273
+ cur_token_len = (image.shape[1] // 14) * (image.shape[2] // 14)
274
  replace_token = DEFAULT_IMAGE_PATCH_TOKEN * cur_token_len
275
+ if getattr(self.model.config, "mm_use_im_start_end", False):
276
+ replace_token = (
277
+ DEFAULT_IM_START_TOKEN
278
+ + replace_token
279
+ + DEFAULT_IM_END_TOKEN
280
+ )
281
  prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token, 1)
282
  images = new_images
283
  else:
284
+ images = image_processor(images, return_tensors="pt")[
285
+ "pixel_values"
286
+ ]
287
  images = images.to(self.model.device, dtype=torch.float16)
288
+ replace_token = (
289
+ DEFAULT_IMAGE_PATCH_TOKEN * 256
290
+ ) # HACK: 256 is the max image token length hacked
291
+ if getattr(self.model.config, "mm_use_im_start_end", False):
292
+ replace_token = (
293
+ DEFAULT_IM_START_TOKEN
294
+ + replace_token
295
+ + DEFAULT_IM_END_TOKEN
296
+ )
297
  prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
298
  else:
299
  images = None
 
325
  for i in range(max_new_tokens):
326
  if i == 0:
327
  out = model(
328
+ torch.as_tensor([input_ids]).cuda(), use_cache=True, **image_args
329
+ )
 
330
  logits = out.logits
331
  past_key_values = out.past_key_values
332
  else:
333
  attention_mask = torch.ones(
334
+ 1, past_key_values[0][0].shape[-2] + 1, device="cuda"
335
+ )
336
+ out = model(
337
+ input_ids=torch.as_tensor([[token]], device="cuda"),
338
+ use_cache=True,
339
+ attention_mask=attention_mask,
340
+ past_key_values=past_key_values,
341
+ )
342
  logits = out.logits
343
  past_key_values = out.past_key_values
344
 
 
420
  worker.send_heart_beat()
421
  generator = worker.generate_stream_gate(params)
422
  background_tasks = BackgroundTasks()
423
+ background_tasks.add_task(
424
+ partial(release_model_semaphore, fn=worker.send_heart_beat)
425
+ )
426
  return StreamingResponse(generator, background=background_tasks)
427
 
428
 
 
435
  parser = argparse.ArgumentParser()
436
  parser.add_argument("--host", type=str, default="localhost")
437
  parser.add_argument("--port", type=int, default=21002)
438
+ parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
439
+ parser.add_argument(
440
+ "--controller-address", type=str, default="http://localhost:21001"
441
+ )
442
  parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
443
  parser.add_argument("--model-name", type=str)
444
+ parser.add_argument(
445
+ "--multi-modal",
446
+ action="store_true",
447
+ help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.",
448
+ )
449
  parser.add_argument("--keep-aspect-ratio", action="store_true")
450
  parser.add_argument("--num-gpus", type=int, default=1)
451
  parser.add_argument("--limit-model-concurrency", type=int, default=5)
 
455
  logger.info(f"args: {args}")
456
 
457
  if args.multi_modal:
458
+ logger.warning(
459
+ "Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path."
460
+ )
461
+
462
+ worker = ModelWorker(
463
+ args.controller_address,
464
+ args.worker_address,
465
+ worker_id,
466
+ args.no_register,
467
+ args.model_path,
468
+ args.model_name,
469
+ args.keep_aspect_ratio,
470
+ args.num_gpus,
471
+ )
472
  uvicorn.run(app, host=args.host, port=args.port, log_level="info")
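Most of the model_worker.py hunks are reformatting, but the queue-length accounting is worth spelling out: the worker caps concurrent generations with an `asyncio.Semaphore` and reports `limit - free_slots + waiters` to the controller in its heartbeat. A rough sketch of that computation, using the same (private) semaphore attributes the worker relies on:

```python
# Rough sketch of the worker-side concurrency accounting shown above.
import asyncio

limit_model_concurrency = 5  # mirrors --limit-model-concurrency
model_semaphore = asyncio.Semaphore(limit_model_concurrency)


def get_queue_length(semaphore: asyncio.Semaphore, limit: int) -> int:
    # busy slots = limit - semaphore._value; add requests still waiting to acquire
    waiters = len(semaphore._waiters) if semaphore._waiters is not None else 0
    return limit - semaphore._value + waiters
```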
model/llava/serve/test_message.py CHANGED
@@ -2,7 +2,6 @@ import argparse
 import json
 
 import requests
-
 from llava.conversation import default_conversation
 
 
@@ -17,8 +16,9 @@ def main():
     models.sort()
     print(f"Models: {models}")
 
-    ret = requests.post(controller_addr + "/get_worker_address",
-                        json={"model": args.model_name})
+    ret = requests.post(
+        controller_addr + "/get_worker_address", json={"model": args.model_name}
+    )
     worker_addr = ret.json()["address"]
     print(f"worker_addr: {worker_addr}")
 
@@ -37,11 +37,17 @@ def main():
         "temperature": 0.7,
         "stop": conv.sep,
     }
-    response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
-                             json=pload, stream=True)
+    response = requests.post(
+        worker_addr + "/worker_generate_stream",
+        headers=headers,
+        json=pload,
+        stream=True,
+    )
 
     print(prompt.replace(conv.sep, "\n"), end="")
-    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
+    for chunk in response.iter_lines(
+        chunk_size=8192, decode_unicode=False, delimiter=b"\0"
+    ):
         if chunk:
             data = json.loads(chunk.decode("utf-8"))
             output = data["text"].split(conv.sep)[-1]
@@ -51,12 +57,15 @@
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
+    parser.add_argument(
+        "--controller-address", type=str, default="http://localhost:21001"
+    )
     parser.add_argument("--worker-address", type=str)
     parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    parser.add_argument("--max-new-tokens", type=int, default=32)
-    parser.add_argument("--message", type=str, default=
-        "Tell me a story with more than 1000 words.")
+    parser.add_argument(
+        "--message", type=str, default="Tell me a story with more than 1000 words."
+    )
     args = parser.parse_args()
 
     main()
model/llava/train/llama_flash_attn_monkey_patch.py CHANGED
@@ -2,15 +2,13 @@
2
  from typing import List, Optional, Tuple
3
 
4
  import torch
5
- from torch import nn
6
-
7
  import transformers
8
- from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
9
-
10
  from einops import rearrange
11
-
12
  from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
13
- from flash_attn.bert_padding import unpad_input, pad_input
 
 
14
 
15
  def forward(
16
  self,
@@ -19,20 +17,28 @@ def forward(
19
  attention_mask: Optional[torch.Tensor] = None,
20
  output_attentions: bool = False,
21
  use_cache: bool = False,
22
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
23
- Optional[Tuple[torch.Tensor]]]:
24
  """Input shape: Batch x Time x Channel
25
-
26
  attention_mask: [bsz, q_len]
27
  """
28
  bsz, q_len, _ = hidden_states.size()
29
 
30
- query_states = self.q_proj(hidden_states).view(
31
- bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
32
- key_states = self.k_proj(hidden_states).view(
33
- bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
34
- value_states = self.v_proj(hidden_states).view(
35
- bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
 
 
 
 
 
 
 
 
36
  # [bsz, q_len, nh, hd]
37
  # [bsz, nh, q_len, hd]
38
 
@@ -42,11 +48,9 @@ def forward(
42
  offset = past_key_value[0].shape[-2]
43
  kv_seq_len += offset
44
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
45
- query_states, key_states = apply_rotary_pos_emb(query_states,
46
- key_states,
47
- cos,
48
- sin,
49
- offset=offset)
50
  # [bsz, nh, t, hd]
51
  assert not output_attentions, "output_attentions is not supported"
52
  assert not use_cache, "use_cache is not supported"
@@ -56,47 +60,55 @@ def forward(
56
  # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
57
 
58
  # transform the data into the format required by flash attention
59
- qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd]
60
- qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
 
 
61
  # We have disabled _prepare_decoder_attention_mask in LlamaModel
62
  # the attention_mask should be the same as the key_padding_mask
63
  key_padding_mask = attention_mask
64
 
65
-
66
  if key_padding_mask is None:
67
- qkv = rearrange(qkv, 'b s ... -> (b s) ...')
68
  max_s = q_len
69
- cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32,
70
- device=qkv.device)
 
71
  output = flash_attn_unpadded_qkvpacked_func(
72
- qkv, cu_q_lens, max_s, 0.0,
73
- softmax_scale=None, causal=True
74
  )
75
- output = rearrange(output, '(b s) ... -> b s ...', b=bsz)
76
  else:
77
  nheads = qkv.shape[-2]
78
- x = rearrange(qkv, 'b s three h d -> b s (three h d)')
79
  x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
80
- x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
 
 
81
  output_unpad = flash_attn_unpadded_qkvpacked_func(
82
- x_unpad, cu_q_lens, max_s, 0.0,
83
- softmax_scale=None, causal=True
 
 
 
 
 
 
84
  )
85
- output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
86
- indices, bsz, q_len),
87
- 'b s (h d) -> b s h d', h=nheads)
88
- return self.o_proj(rearrange(output,
89
- 'b s h d -> b s (h d)')), None, None
90
 
91
 
92
  # Disable the transformation of the attention mask in LlamaModel as the flash attention
93
  # requires the attention mask to be the same as the key_padding_mask
94
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
95
- inputs_embeds, past_key_values_length):
 
96
  # [bsz, seq_len]
97
  return attention_mask
98
 
99
 
100
  def replace_llama_attn_with_flash_attn():
101
- transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
 
 
102
  transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
 
2
  from typing import List, Optional, Tuple
3
 
4
  import torch
 
 
5
  import transformers
 
 
6
  from einops import rearrange
7
+ from flash_attn.bert_padding import pad_input, unpad_input
8
  from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
9
+ from torch import nn
10
+ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
11
+
12
 
13
  def forward(
14
  self,
 
17
  attention_mask: Optional[torch.Tensor] = None,
18
  output_attentions: bool = False,
19
  use_cache: bool = False,
20
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 
21
  """Input shape: Batch x Time x Channel
22
+
23
  attention_mask: [bsz, q_len]
24
  """
25
  bsz, q_len, _ = hidden_states.size()
26
 
27
+ query_states = (
28
+ self.q_proj(hidden_states)
29
+ .view(bsz, q_len, self.num_heads, self.head_dim)
30
+ .transpose(1, 2)
31
+ )
32
+ key_states = (
33
+ self.k_proj(hidden_states)
34
+ .view(bsz, q_len, self.num_heads, self.head_dim)
35
+ .transpose(1, 2)
36
+ )
37
+ value_states = (
38
+ self.v_proj(hidden_states)
39
+ .view(bsz, q_len, self.num_heads, self.head_dim)
40
+ .transpose(1, 2)
41
+ )
42
  # [bsz, q_len, nh, hd]
43
  # [bsz, nh, q_len, hd]
44
 
 
48
  offset = past_key_value[0].shape[-2]
49
  kv_seq_len += offset
50
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
51
+ query_states, key_states = apply_rotary_pos_emb(
52
+ query_states, key_states, cos, sin, offset=offset
53
+ )
 
 
54
  # [bsz, nh, t, hd]
55
  assert not output_attentions, "output_attentions is not supported"
56
  assert not use_cache, "use_cache is not supported"
 
60
  # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
61
 
62
  # transform the data into the format required by flash attention
63
+ qkv = torch.stack(
64
+ [query_states, key_states, value_states], dim=2
65
+ ) # [bsz, nh, 3, q_len, hd]
66
+ qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
67
  # We have disabled _prepare_decoder_attention_mask in LlamaModel
68
  # the attention_mask should be the same as the key_padding_mask
69
  key_padding_mask = attention_mask
70
 
 
71
  if key_padding_mask is None:
72
+ qkv = rearrange(qkv, "b s ... -> (b s) ...")
73
  max_s = q_len
74
+ cu_q_lens = torch.arange(
75
+ 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
76
+ )
77
  output = flash_attn_unpadded_qkvpacked_func(
78
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
 
79
  )
80
+ output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
81
  else:
82
  nheads = qkv.shape[-2]
83
+ x = rearrange(qkv, "b s three h d -> b s (three h d)")
84
  x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
85
+ x_unpad = rearrange(
86
+ x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
87
+ )
88
  output_unpad = flash_attn_unpadded_qkvpacked_func(
89
+ x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
90
+ )
91
+ output = rearrange(
92
+ pad_input(
93
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
94
+ ),
95
+ "b s (h d) -> b s h d",
96
+ h=nheads,
97
  )
98
+ return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
 
 
 
 
99
 
100
 
101
  # Disable the transformation of the attention mask in LlamaModel as the flash attention
102
  # requires the attention mask to be the same as the key_padding_mask
103
+ def _prepare_decoder_attention_mask(
104
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
105
+ ):
106
  # [bsz, seq_len]
107
  return attention_mask
108
 
109
 
110
  def replace_llama_attn_with_flash_attn():
111
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
112
+ _prepare_decoder_attention_mask
113
+ )
114
  transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
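Because `replace_llama_attn_with_flash_attn()` rebinds `LlamaAttention.forward` and `LlamaModel._prepare_decoder_attention_mask` on the transformers module itself, it only takes effect if it runs before the LLaMA model is instantiated. A hedged sketch of a training entrypoint that enables it first (the repository's train_mem.py presumably does the equivalent, but the exact call site and the `train` import path here are assumptions):

```python
# Sketch: install the flash-attention monkey patch, then import and run training.
from llava.train.llama_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
)

replace_llama_attn_with_flash_attn()  # patches transformers' LlamaAttention globally

from llava.train.train import train  # noqa: E402  (import after patching; assumed path)

if __name__ == "__main__":
    train()
```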
model/llava/train/llava_trainer.py CHANGED
@@ -1,9 +1,9 @@
 import os
+from typing import Dict, Optional, Sequence
+
 import torch
 import torch.nn as nn
-
 from transformers import Trainer
-from typing import Dict, Optional, Sequence
 
 
 def unwrap_model(model: nn.Module) -> nn.Module:
@@ -21,9 +21,8 @@ def unwrap_model(model: nn.Module) -> nn.Module:
 
 
 class LLaVATrainer(Trainer):
-
     def _save(self, output_dir: Optional[str] = None, state_dict=None):
-        if getattr(self.args, 'tune_mm_mlp_adapter', False):
+        if getattr(self.args, "tune_mm_mlp_adapter", False):
             # Save the model
             _state_dict = state_dict
             if _state_dict is None:
@@ -32,18 +31,23 @@ class LLaVATrainer(Trainer):
                 _state_dict = model_to_save.state_dict()
 
             weight_to_save = {}
-            keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in']
+            keys_to_match = ["mm_projector", "embed_tokens", "embed_in"]
             for k, v in _state_dict.items():
                 if any(key_match in k for key_match in keys_to_match):
                     weight_to_save[k] = v
 
-            current_folder = output_dir.split('/')[-1]
+            current_folder = output_dir.split("/")[-1]
             parent_folder = os.path.dirname(output_dir)
-            if current_folder.startswith('checkpoint-'):
+            if current_folder.startswith("checkpoint-"):
                 mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                 os.makedirs(mm_projector_folder, exist_ok=True)
-                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
+                torch.save(
+                    weight_to_save,
+                    os.path.join(mm_projector_folder, f"{current_folder}.bin"),
+                )
             else:
-                torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
+                torch.save(
+                    weight_to_save, os.path.join(output_dir, f"mm_projector.bin")
+                )
 
         super(LLaVATrainer, self)._save(output_dir, state_dict)
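The `_save` override above only checkpoints tensors whose names contain `mm_projector`, `embed_tokens`, or `embed_in` when `tune_mm_mlp_adapter` is set, which keeps adapter-only checkpoints small. A minimal sketch of that filtering step on its own (the output path is a placeholder):

```python
# Sketch of the selective checkpointing done in LLaVATrainer._save: keep only
# the multimodal-adapter tensors instead of the full model state dict.
import torch


def save_adapter_weights(state_dict: dict, output_path: str) -> None:
    keys_to_match = ["mm_projector", "embed_tokens", "embed_in"]
    weight_to_save = {
        k: v.cpu()
        for k, v in state_dict.items()
        if any(key in k for key in keys_to_match)
    }
    torch.save(weight_to_save, output_path)
```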
model/llava/train/train.py CHANGED
@@ -14,25 +14,22 @@
14
  # See the License for the specific language governing permissions and
15
  # limitations under the License.
16
 
17
- import os
18
  import copy
19
- from dataclasses import dataclass, field
20
  import json
21
  import logging
 
22
  import pathlib
 
23
  from typing import Dict, Optional, Sequence
24
 
25
  import torch
26
-
27
  import transformers
28
- from torch.utils.data import Dataset
29
- from llava.train.llava_trainer import LLaVATrainer
30
-
31
  from llava import conversation as conversation_lib
32
  from llava.model import *
33
-
34
  from PIL import Image
35
- import torch.nn as nn
36
 
37
  # TODO: import and use code from ../data/dataset.py
38
 
@@ -54,21 +51,24 @@ class ModelArguments:
54
  freeze_backbone: bool = field(default=False)
55
  tune_mm_mlp_adapter: bool = field(default=False)
56
  vision_tower: Optional[str] = field(default=None)
57
- mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
 
 
58
  pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
59
  mm_use_im_start_end: bool = field(default=False)
60
 
61
 
62
  @dataclass
63
  class DataArguments:
64
- data_path: str = field(default=None,
65
- metadata={"help": "Path to the training data."})
 
66
  lazy_preprocess: bool = False
67
  is_multimodal: bool = False
68
  sep_image_conv_front: bool = False
69
  image_token_len: int = 0
70
  image_folder: Optional[str] = field(default=None)
71
- image_aspect_ratio: str = 'square'
72
 
73
 
74
  @dataclass
@@ -81,21 +81,16 @@ class TrainingArguments(transformers.TrainingArguments):
81
  model_max_length: int = field(
82
  default=512,
83
  metadata={
84
- "help":
85
- "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
86
  },
87
  )
88
 
89
 
90
- def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
91
- output_dir: str):
92
  """Collects the state dict and dump to disk."""
93
  state_dict = trainer.model.state_dict()
94
  if trainer.args.should_save:
95
- cpu_state_dict = {
96
- key: value.cpu()
97
- for key, value in state_dict.items()
98
- }
99
  del state_dict
100
  trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
101
 
@@ -117,16 +112,19 @@ def smart_tokenizer_and_embedding_resize(
117
  output_embeddings = model.get_output_embeddings().weight.data
118
 
119
  input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
120
- dim=0, keepdim=True)
 
121
  output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
122
- dim=0, keepdim=True)
 
123
 
124
  input_embeddings[-num_new_tokens:] = input_embeddings_avg
125
  output_embeddings[-num_new_tokens:] = output_embeddings_avg
126
 
127
 
128
- def _tokenize_fn(strings: Sequence[str],
129
- tokenizer: transformers.PreTrainedTokenizer) -> Dict:
 
130
  """Tokenize a list of strings."""
131
  tokenized_list = [
132
  tokenizer(
@@ -135,11 +133,10 @@ def _tokenize_fn(strings: Sequence[str],
135
  padding="longest",
136
  max_length=tokenizer.model_max_length,
137
  truncation=True,
138
- ) for text in strings
139
- ]
140
- input_ids = labels = [
141
- tokenized.input_ids[0] for tokenized in tokenized_list
142
  ]
 
143
  input_ids_lens = labels_lens = [
144
  tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
145
  for tokenized in tokenized_list
@@ -159,7 +156,7 @@ def _mask_targets(target, tokenized_lens, speakers):
159
  target[:cur_idx] = IGNORE_INDEX
160
  for tokenized_len, speaker in zip(tokenized_lens, speakers):
161
  if speaker == "human":
162
- target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
163
  cur_idx += tokenized_len
164
 
165
 
@@ -175,9 +172,10 @@ def _add_speaker_and_signal(header, source, get_conversation=True):
175
  elif from_str.lower() == "gpt":
176
  from_str = conversation_lib.default_conversation.roles[1]
177
  else:
178
- from_str = 'unknown'
179
- sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
180
- sentence["value"] + END_SIGNAL)
 
181
  if get_conversation:
182
  conversation += sentence["value"]
183
  conversation += BEGIN_SIGNAL
@@ -189,22 +187,34 @@ def preprocess_multimodal(
189
  multimodal_cfg: dict,
190
  cur_token_len: int,
191
  ) -> Dict:
192
- is_multimodal = multimodal_cfg['is_multimodal']
193
  # image_token_len = multimodal_cfg['image_token_len']
194
  image_token_len = cur_token_len
195
  if not is_multimodal:
196
  return sources
197
 
198
  for source in sources:
199
- if multimodal_cfg['sep_image_conv_front']:
200
- assert DEFAULT_IMAGE_TOKEN in source[0]['value']
201
- source[0]['value'] = source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
202
- source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value']
 
 
 
 
 
 
 
 
203
  for sentence in source:
204
  replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
205
- if multimodal_cfg['use_im_start_end']:
206
- replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
207
- sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
 
 
 
 
208
 
209
  return sources
210
 
@@ -279,6 +289,7 @@ def preprocess_v1(
279
  labels=targets,
280
  )
281
 
 
282
  def preprocess_mpt(
283
  sources,
284
  tokenizer: transformers.PreTrainedTokenizer,
@@ -317,9 +328,11 @@ def preprocess_mpt(
317
  total_len = int(target.ne(tokenizer.pad_token_id).sum())
318
 
319
  rounds = conversation.split(conv.sep)
320
- re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
321
  for conv_idx in range(3, len(rounds), 2):
322
- re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt
 
 
323
  cur_len = 0
324
  target[:cur_len] = IGNORE_INDEX
325
  for i, rou in enumerate(re_rounds):
@@ -330,7 +343,9 @@ def preprocess_mpt(
330
  if len(parts) != 2:
331
  break
332
  parts[0] += sep
333
- round_len = len(tokenizer(rou).input_ids) + len(tokenizer(conv.sep).input_ids)
 
 
334
  instruction_len = len(tokenizer(parts[0]).input_ids)
335
  target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
336
 
@@ -377,8 +392,9 @@ def preprocess(
377
  input_ids = conversations_tokenized["input_ids"]
378
  targets = copy.deepcopy(input_ids)
379
  for target, source in zip(targets, sources):
380
- tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source],
381
- tokenizer)["input_ids_lens"]
 
382
  speakers = [sentence["from"] for sentence in source]
383
  _mask_targets(target, tokenized_lens, speakers)
384
 
@@ -388,8 +404,7 @@ def preprocess(
388
  class SupervisedDataset(Dataset):
389
  """Dataset for supervised fine-tuning."""
390
 
391
- def __init__(self, data_path: str,
392
- tokenizer: transformers.PreTrainedTokenizer):
393
  super(SupervisedDataset, self).__init__()
394
  logging.warning("Loading data...")
395
  list_data_dict = json.load(open(data_path, "r"))
@@ -411,9 +426,12 @@ class SupervisedDataset(Dataset):
411
  class LazySupervisedDataset(Dataset):
412
  """Dataset for supervised fine-tuning."""
413
 
414
- def __init__(self, data_path: str,
415
- tokenizer: transformers.PreTrainedTokenizer,
416
- multimodal_cfg: dict):
 
 
 
417
  super(LazySupervisedDataset, self).__init__()
418
  logging.warning("Loading data...")
419
  list_data_dict = json.load(open(data_path, "r"))
@@ -431,54 +449,74 @@ class LazySupervisedDataset(Dataset):
431
  if isinstance(i, int):
432
  sources = [sources]
433
  assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
434
- if 'image' in sources[0]:
435
- image_file = self.list_data_dict[i]['image']
436
- image_folder = self.multimodal_cfg['image_folder']
437
- processor = self.multimodal_cfg['image_processor']
438
- image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
439
- if self.multimodal_cfg['image_aspect_ratio'] == 'keep':
440
  max_hw, min_hw = max(image.size), min(image.size)
441
  aspect_ratio = max_hw / min_hw
442
  max_len, min_len = 448, 224
443
  shortest_edge = int(min(max_len / aspect_ratio, min_len))
444
- image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge})['pixel_values'][0]
445
- elif self.multimodal_cfg['image_aspect_ratio'] == 'pad':
 
 
 
 
 
 
446
  def expand2square(pil_img, background_color):
447
  width, height = pil_img.size
448
  if width == height:
449
  return pil_img
450
  elif width > height:
451
- result = Image.new(pil_img.mode, (width, width), background_color)
 
 
452
  result.paste(pil_img, (0, (width - height) // 2))
453
  return result
454
  else:
455
- result = Image.new(pil_img.mode, (height, height), background_color)
 
 
456
  result.paste(pil_img, ((height - width) // 2, 0))
457
  return result
458
- image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
459
- image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
 
 
 
 
 
460
  else:
461
- image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
462
- cur_token_len = (image.shape[1]//14) * (image.shape[2]//14) # FIXME: 14 is hardcoded patch size
 
 
 
 
463
  sources = preprocess_multimodal(
464
  copy.deepcopy([e["conversations"] for e in sources]),
465
- self.multimodal_cfg, cur_token_len)
 
 
466
  else:
467
  sources = copy.deepcopy([e["conversations"] for e in sources])
468
- data_dict = preprocess(
469
- sources,
470
- self.tokenizer)
471
  if isinstance(i, int):
472
- data_dict = dict(input_ids=data_dict["input_ids"][0],
473
- labels=data_dict["labels"][0])
 
474
 
475
  # image exist in the data
476
- if 'image' in self.list_data_dict[i]:
477
- data_dict['image'] = image
478
- elif self.multimodal_cfg['is_multimodal']:
479
  # image does not exist in the data, but the model is multimodal
480
- crop_size = self.multimodal_cfg['image_processor'].crop_size
481
- data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
482
  return data_dict
483
 
484
 
@@ -489,59 +527,65 @@ class DataCollatorForSupervisedDataset(object):
489
  tokenizer: transformers.PreTrainedTokenizer
490
 
491
  def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
492
- input_ids, labels = tuple([instance[key] for instance in instances]
493
- for key in ("input_ids", "labels"))
 
494
  input_ids = torch.nn.utils.rnn.pad_sequence(
495
- input_ids,
496
- batch_first=True,
497
- padding_value=self.tokenizer.pad_token_id)
498
- labels = torch.nn.utils.rnn.pad_sequence(labels,
499
- batch_first=True,
500
- padding_value=IGNORE_INDEX)
501
  batch = dict(
502
  input_ids=input_ids,
503
  labels=labels,
504
  attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
505
  )
506
 
507
- if 'image' in instances[0]:
508
- images = [instance['image'] for instance in instances]
509
  if all(x is not None and x.shape == images[0].shape for x in images):
510
- batch['images'] = torch.stack(images)
511
  else:
512
- batch['images'] = images
513
 
514
  return batch
515
 
516
 
517
- def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
518
- data_args) -> Dict:
 
519
  """Make dataset and collator for supervised fine-tuning."""
520
- dataset_cls = (LazySupervisedDataset
521
- if data_args.lazy_preprocess else SupervisedDataset)
522
- train_dataset = dataset_cls(tokenizer=tokenizer,
523
- data_path=data_args.data_path,
524
- multimodal_cfg=dict(
525
- is_multimodal=data_args.is_multimodal,
526
- sep_image_conv_front=data_args.sep_image_conv_front,
527
- image_token_len=data_args.image_token_len,
528
- image_folder=data_args.image_folder,
529
- image_aspect_ratio=data_args.image_aspect_ratio,
530
- use_im_start_end=getattr(data_args, 'mm_use_im_start_end', False),
531
- image_processor=getattr(data_args, 'image_processor', None)))
 
 
 
 
532
  data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
533
- return dict(train_dataset=train_dataset,
534
- eval_dataset=None,
535
- data_collator=data_collator)
536
 
537
 
538
  def train():
539
  parser = transformers.HfArgumentParser(
540
- (ModelArguments, DataArguments, TrainingArguments))
 
541
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
542
 
543
  if model_args.vision_tower is not None:
544
- if 'mpt' in model_args.model_name_or_path:
545
  model = LlavaMPTForCausalLM.from_pretrained(
546
  model_args.model_name_or_path,
547
  cache_dir=training_args.cache_dir,
@@ -561,12 +605,12 @@ def train():
561
  if model_args.freeze_backbone:
562
  model.model.requires_grad_(False)
563
 
564
- if 'mpt' in model_args.model_name_or_path:
565
  tokenizer = transformers.AutoTokenizer.from_pretrained(
566
  model_args.model_name_or_path,
567
  cache_dir=training_args.cache_dir,
568
  model_max_length=training_args.model_max_length,
569
- padding_side="right"
570
  )
571
  else:
572
  tokenizer = transformers.AutoTokenizer.from_pretrained(
@@ -585,23 +629,29 @@ def train():
585
  model=model,
586
  )
587
  if "llama" in model_args.model_name_or_path:
588
- tokenizer.add_special_tokens({
589
- "eos_token": DEFAULT_EOS_TOKEN,
590
- "bos_token": DEFAULT_BOS_TOKEN,
591
- "unk_token": DEFAULT_UNK_TOKEN,
592
- })
 
 
593
  else:
594
  tokenizer.pad_token = tokenizer.unk_token
595
  if "mpt" in model_args.model_name_or_path:
596
- conversation_lib.default_conversation = conversation_lib.conv_templates["mpt"]
 
 
597
  else:
598
- conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1_1"]
 
 
599
 
600
  if model_args.vision_tower is not None:
601
  model_vision_dict = model.get_model().initialize_vision_modules(
602
  vision_tower=model_args.vision_tower,
603
  mm_vision_select_layer=model_args.mm_vision_select_layer,
604
- pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter
605
  )
606
  dtype = torch.float32
607
  if training_args.fp16:
@@ -609,13 +659,15 @@ def train():
609
  if training_args.bf16:
610
  dtype = torch.bfloat16
611
  model.get_model().vision_tower[0].to(dtype=dtype, device=training_args.device)
612
- vision_config = model_vision_dict['vision_config']
613
 
614
- data_args.image_token_len = model_vision_dict['image_token_len']
615
- data_args.image_processor = model_vision_dict['image_processor']
616
  data_args.is_multimodal = True
617
 
618
- model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
 
 
619
  if model_args.tune_mm_mlp_adapter:
620
  model.requires_grad_(False)
621
  for p in model.get_model().mm_projector.parameters():
@@ -626,45 +678,66 @@ def train():
626
  for p in model.get_model().mm_projector.parameters():
627
  p.requires_grad = False
628
 
629
- model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
630
- vision_config.use_im_start_end = training_args.use_im_start_end = model_args.mm_use_im_start_end
 
 
 
 
631
  model.config.sep_image_conv_front = data_args.sep_image_conv_front
632
- model.initialize_vision_tokenizer(mm_use_im_start_end=model_args.mm_use_im_start_end, tokenizer=tokenizer, device=training_args.device,
633
- tune_mm_mlp_adapter=model_args.tune_mm_mlp_adapter, pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter)
 
 
 
 
 
634
 
635
  params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
636
  if len(params_no_grad) > 0:
637
  if training_args.fsdp is not None and len(training_args.fsdp) > 0:
638
  if len(params_no_grad) < 10:
639
- print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad))
 
 
 
 
640
  else:
641
- print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. format(len(params_no_grad), ', '.join(params_no_grad[:10])))
642
- print("[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental.")
643
- print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")
 
 
 
 
 
 
 
 
 
 
 
644
 
645
- from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
646
  def patch_FSDP_use_orig_params(func):
647
  def wrap_func(*args, **kwargs):
648
- use_orig_params = kwargs.pop('use_orig_params', True)
649
  return func(*args, **kwargs, use_orig_params=use_orig_params)
 
650
  return wrap_func
651
 
652
  FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)
653
 
654
- data_module = make_supervised_data_module(tokenizer=tokenizer,
655
- data_args=data_args)
656
- trainer = LLaVATrainer(model=model,
657
- tokenizer=tokenizer,
658
- args=training_args,
659
- **data_module)
660
 
661
  if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
662
  trainer.train(resume_from_checkpoint=True)
663
  else:
664
  trainer.train()
665
  trainer.save_state()
666
- safe_save_model_for_hf_trainer(trainer=trainer,
667
- output_dir=training_args.output_dir)
668
 
669
 
670
  if __name__ == "__main__":
 
14
  # See the License for the specific language governing permissions and
15
  # limitations under the License.
16
 
 
17
  import copy
 
18
  import json
19
  import logging
20
+ import os
21
  import pathlib
22
+ from dataclasses import dataclass, field
23
  from typing import Dict, Optional, Sequence
24
 
25
  import torch
26
+ import torch.nn as nn
27
  import transformers
 
 
 
28
  from llava import conversation as conversation_lib
29
  from llava.model import *
30
+ from llava.train.llava_trainer import LLaVATrainer
31
  from PIL import Image
32
+ from torch.utils.data import Dataset
33
 
34
  # TODO: import and use code from ../data/dataset.py
35
 
 
51
  freeze_backbone: bool = field(default=False)
52
  tune_mm_mlp_adapter: bool = field(default=False)
53
  vision_tower: Optional[str] = field(default=None)
54
+ mm_vision_select_layer: Optional[int] = field(
55
+ default=-1
56
+ ) # default to the last layer
57
  pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
58
  mm_use_im_start_end: bool = field(default=False)
59
 
60
 
61
  @dataclass
62
  class DataArguments:
63
+ data_path: str = field(
64
+ default=None, metadata={"help": "Path to the training data."}
65
+ )
66
  lazy_preprocess: bool = False
67
  is_multimodal: bool = False
68
  sep_image_conv_front: bool = False
69
  image_token_len: int = 0
70
  image_folder: Optional[str] = field(default=None)
71
+ image_aspect_ratio: str = "square"
72
 
73
 
74
  @dataclass
 
81
  model_max_length: int = field(
82
  default=512,
83
  metadata={
84
+ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
 
85
  },
86
  )
87
 
88
 
89
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
 
90
  """Collects the state dict and dump to disk."""
91
  state_dict = trainer.model.state_dict()
92
  if trainer.args.should_save:
93
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
 
 
 
94
  del state_dict
95
  trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
96
 
 
112
  output_embeddings = model.get_output_embeddings().weight.data
113
 
114
  input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
115
+ dim=0, keepdim=True
116
+ )
117
  output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
118
+ dim=0, keepdim=True
119
+ )
120
 
121
  input_embeddings[-num_new_tokens:] = input_embeddings_avg
122
  output_embeddings[-num_new_tokens:] = output_embeddings_avg
123
 
124
 
125
+ def _tokenize_fn(
126
+ strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer
127
+ ) -> Dict:
128
  """Tokenize a list of strings."""
129
  tokenized_list = [
130
  tokenizer(
 
133
  padding="longest",
134
  max_length=tokenizer.model_max_length,
135
  truncation=True,
136
+ )
137
+ for text in strings
 
 
138
  ]
139
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
140
  input_ids_lens = labels_lens = [
141
  tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
142
  for tokenized in tokenized_list
 
156
  target[:cur_idx] = IGNORE_INDEX
157
  for tokenized_len, speaker in zip(tokenized_lens, speakers):
158
  if speaker == "human":
159
+ target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
160
  cur_idx += tokenized_len
161
 
162
 
 
172
  elif from_str.lower() == "gpt":
173
  from_str = conversation_lib.default_conversation.roles[1]
174
  else:
175
+ from_str = "unknown"
176
+ sentence["value"] = (
177
+ BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
178
+ )
179
  if get_conversation:
180
  conversation += sentence["value"]
181
  conversation += BEGIN_SIGNAL
 
187
  multimodal_cfg: dict,
188
  cur_token_len: int,
189
  ) -> Dict:
190
+ is_multimodal = multimodal_cfg["is_multimodal"]
191
  # image_token_len = multimodal_cfg['image_token_len']
192
  image_token_len = cur_token_len
193
  if not is_multimodal:
194
  return sources
195
 
196
  for source in sources:
197
+ if multimodal_cfg["sep_image_conv_front"]:
198
+ assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
199
+ source[0]["value"] = (
200
+ source[0]["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
201
+ )
202
+ source[0]["value"] = (
203
+ DEFAULT_IMAGE_TOKEN
204
+ + conversation_lib.default_conversation.sep
205
+ + conversation_lib.default_conversation.roles[0]
206
+ + ": "
207
+ + source[0]["value"]
208
+ )
209
  for sentence in source:
210
  replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
211
+ if multimodal_cfg["use_im_start_end"]:
212
+ replace_token = (
213
+ DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
214
+ )
215
+ sentence["value"] = sentence["value"].replace(
216
+ DEFAULT_IMAGE_TOKEN, replace_token
217
+ )
218
 
219
  return sources
220
 
 
289
  labels=targets,
290
  )
291
 
292
+
293
  def preprocess_mpt(
294
  sources,
295
  tokenizer: transformers.PreTrainedTokenizer,
 
328
  total_len = int(target.ne(tokenizer.pad_token_id).sum())
329
 
330
  rounds = conversation.split(conv.sep)
331
+ re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
332
  for conv_idx in range(3, len(rounds), 2):
333
+ re_rounds.append(
334
+ conv.sep.join(rounds[conv_idx : conv_idx + 2])
335
+ ) # user + gpt
336
  cur_len = 0
337
  target[:cur_len] = IGNORE_INDEX
338
  for i, rou in enumerate(re_rounds):
 
343
  if len(parts) != 2:
344
  break
345
  parts[0] += sep
346
+ round_len = len(tokenizer(rou).input_ids) + len(
347
+ tokenizer(conv.sep).input_ids
348
+ )
349
  instruction_len = len(tokenizer(parts[0]).input_ids)
350
  target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
351
 
 
392
  input_ids = conversations_tokenized["input_ids"]
393
  targets = copy.deepcopy(input_ids)
394
  for target, source in zip(targets, sources):
395
+ tokenized_lens = _tokenize_fn(
396
+ [header] + [s["value"] for s in source], tokenizer
397
+ )["input_ids_lens"]
398
  speakers = [sentence["from"] for sentence in source]
399
  _mask_targets(target, tokenized_lens, speakers)
400
 
 
404
  class SupervisedDataset(Dataset):
405
  """Dataset for supervised fine-tuning."""
406
 
407
+ def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
 
408
  super(SupervisedDataset, self).__init__()
409
  logging.warning("Loading data...")
410
  list_data_dict = json.load(open(data_path, "r"))
 
426
  class LazySupervisedDataset(Dataset):
427
  """Dataset for supervised fine-tuning."""
428
 
429
+ def __init__(
430
+ self,
431
+ data_path: str,
432
+ tokenizer: transformers.PreTrainedTokenizer,
433
+ multimodal_cfg: dict,
434
+ ):
435
  super(LazySupervisedDataset, self).__init__()
436
  logging.warning("Loading data...")
437
  list_data_dict = json.load(open(data_path, "r"))
 
449
  if isinstance(i, int):
450
  sources = [sources]
451
  assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
452
+ if "image" in sources[0]:
453
+ image_file = self.list_data_dict[i]["image"]
454
+ image_folder = self.multimodal_cfg["image_folder"]
455
+ processor = self.multimodal_cfg["image_processor"]
456
+ image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
457
+ if self.multimodal_cfg["image_aspect_ratio"] == "keep":
458
  max_hw, min_hw = max(image.size), min(image.size)
459
  aspect_ratio = max_hw / min_hw
460
  max_len, min_len = 448, 224
461
  shortest_edge = int(min(max_len / aspect_ratio, min_len))
462
+ image = processor.preprocess(
463
+ image,
464
+ return_tensors="pt",
465
+ do_center_crop=False,
466
+ size={"shortest_edge": shortest_edge},
467
+ )["pixel_values"][0]
468
+ elif self.multimodal_cfg["image_aspect_ratio"] == "pad":
469
+
470
  def expand2square(pil_img, background_color):
471
  width, height = pil_img.size
472
  if width == height:
473
  return pil_img
474
  elif width > height:
475
+ result = Image.new(
476
+ pil_img.mode, (width, width), background_color
477
+ )
478
  result.paste(pil_img, (0, (width - height) // 2))
479
  return result
480
  else:
481
+ result = Image.new(
482
+ pil_img.mode, (height, height), background_color
483
+ )
484
  result.paste(pil_img, ((height - width) // 2, 0))
485
  return result
486
+
487
+ image = expand2square(
488
+ image, tuple(int(x * 255) for x in processor.image_mean)
489
+ )
490
+ image = processor.preprocess(image, return_tensors="pt")[
491
+ "pixel_values"
492
+ ][0]
493
  else:
494
+ image = processor.preprocess(image, return_tensors="pt")[
495
+ "pixel_values"
496
+ ][0]
497
+ cur_token_len = (image.shape[1] // 14) * (
498
+ image.shape[2] // 14
499
+ ) # FIXME: 14 is hardcoded patch size
500
  sources = preprocess_multimodal(
501
  copy.deepcopy([e["conversations"] for e in sources]),
502
+ self.multimodal_cfg,
503
+ cur_token_len,
504
+ )
505
  else:
506
  sources = copy.deepcopy([e["conversations"] for e in sources])
507
+ data_dict = preprocess(sources, self.tokenizer)
 
 
508
  if isinstance(i, int):
509
+ data_dict = dict(
510
+ input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]
511
+ )
512
 
513
  # image exist in the data
514
+ if "image" in self.list_data_dict[i]:
515
+ data_dict["image"] = image
516
+ elif self.multimodal_cfg["is_multimodal"]:
517
  # image does not exist in the data, but the model is multimodal
518
+ crop_size = self.multimodal_cfg["image_processor"].crop_size
519
+ data_dict["image"] = torch.zeros(3, crop_size["height"], crop_size["width"])
520
  return data_dict
521
 
522
 
 
527
  tokenizer: transformers.PreTrainedTokenizer
528
 
529
  def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
530
+ input_ids, labels = tuple(
531
+ [instance[key] for instance in instances] for key in ("input_ids", "labels")
532
+ )
533
  input_ids = torch.nn.utils.rnn.pad_sequence(
534
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
535
+ )
536
+ labels = torch.nn.utils.rnn.pad_sequence(
537
+ labels, batch_first=True, padding_value=IGNORE_INDEX
538
+ )
 
539
  batch = dict(
540
  input_ids=input_ids,
541
  labels=labels,
542
  attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
543
  )
544
 
545
+ if "image" in instances[0]:
546
+ images = [instance["image"] for instance in instances]
547
  if all(x is not None and x.shape == images[0].shape for x in images):
548
+ batch["images"] = torch.stack(images)
549
  else:
550
+ batch["images"] = images
551
 
552
  return batch
553
 
554
 
555
+ def make_supervised_data_module(
556
+ tokenizer: transformers.PreTrainedTokenizer, data_args
557
+ ) -> Dict:
558
  """Make dataset and collator for supervised fine-tuning."""
559
+ dataset_cls = (
560
+ LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
561
+ )
562
+ train_dataset = dataset_cls(
563
+ tokenizer=tokenizer,
564
+ data_path=data_args.data_path,
565
+ multimodal_cfg=dict(
566
+ is_multimodal=data_args.is_multimodal,
567
+ sep_image_conv_front=data_args.sep_image_conv_front,
568
+ image_token_len=data_args.image_token_len,
569
+ image_folder=data_args.image_folder,
570
+ image_aspect_ratio=data_args.image_aspect_ratio,
571
+ use_im_start_end=getattr(data_args, "mm_use_im_start_end", False),
572
+ image_processor=getattr(data_args, "image_processor", None),
573
+ ),
574
+ )
575
  data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
576
+ return dict(
577
+ train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
578
+ )
579
 
580
 
581
  def train():
582
  parser = transformers.HfArgumentParser(
583
+ (ModelArguments, DataArguments, TrainingArguments)
584
+ )
585
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
586
 
587
  if model_args.vision_tower is not None:
588
+ if "mpt" in model_args.model_name_or_path:
589
  model = LlavaMPTForCausalLM.from_pretrained(
590
  model_args.model_name_or_path,
591
  cache_dir=training_args.cache_dir,
 
605
  if model_args.freeze_backbone:
606
  model.model.requires_grad_(False)
607
 
608
+ if "mpt" in model_args.model_name_or_path:
609
  tokenizer = transformers.AutoTokenizer.from_pretrained(
610
  model_args.model_name_or_path,
611
  cache_dir=training_args.cache_dir,
612
  model_max_length=training_args.model_max_length,
613
+ padding_side="right",
614
  )
615
  else:
616
  tokenizer = transformers.AutoTokenizer.from_pretrained(
 
629
  model=model,
630
  )
631
  if "llama" in model_args.model_name_or_path:
632
+ tokenizer.add_special_tokens(
633
+ {
634
+ "eos_token": DEFAULT_EOS_TOKEN,
635
+ "bos_token": DEFAULT_BOS_TOKEN,
636
+ "unk_token": DEFAULT_UNK_TOKEN,
637
+ }
638
+ )
639
  else:
640
  tokenizer.pad_token = tokenizer.unk_token
641
  if "mpt" in model_args.model_name_or_path:
642
+ conversation_lib.default_conversation = conversation_lib.conv_templates[
643
+ "mpt"
644
+ ]
645
  else:
646
+ conversation_lib.default_conversation = conversation_lib.conv_templates[
647
+ "vicuna_v1_1"
648
+ ]
649
 
650
  if model_args.vision_tower is not None:
651
  model_vision_dict = model.get_model().initialize_vision_modules(
652
  vision_tower=model_args.vision_tower,
653
  mm_vision_select_layer=model_args.mm_vision_select_layer,
654
+ pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter,
655
  )
656
  dtype = torch.float32
657
  if training_args.fp16:
 
659
  if training_args.bf16:
660
  dtype = torch.bfloat16
661
  model.get_model().vision_tower[0].to(dtype=dtype, device=training_args.device)
662
+ vision_config = model_vision_dict["vision_config"]
663
 
664
+ data_args.image_token_len = model_vision_dict["image_token_len"]
665
+ data_args.image_processor = model_vision_dict["image_processor"]
666
  data_args.is_multimodal = True
667
 
668
+ model.config.tune_mm_mlp_adapter = (
669
+ training_args.tune_mm_mlp_adapter
670
+ ) = model_args.tune_mm_mlp_adapter
671
  if model_args.tune_mm_mlp_adapter:
672
  model.requires_grad_(False)
673
  for p in model.get_model().mm_projector.parameters():
 
678
  for p in model.get_model().mm_projector.parameters():
679
  p.requires_grad = False
680
 
681
+ model.config.mm_use_im_start_end = (
682
+ data_args.mm_use_im_start_end
683
+ ) = model_args.mm_use_im_start_end
684
+ vision_config.use_im_start_end = (
685
+ training_args.use_im_start_end
686
+ ) = model_args.mm_use_im_start_end
687
  model.config.sep_image_conv_front = data_args.sep_image_conv_front
688
+ model.initialize_vision_tokenizer(
689
+ mm_use_im_start_end=model_args.mm_use_im_start_end,
690
+ tokenizer=tokenizer,
691
+ device=training_args.device,
692
+ tune_mm_mlp_adapter=model_args.tune_mm_mlp_adapter,
693
+ pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter,
694
+ )
695
 
696
  params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
697
  if len(params_no_grad) > 0:
698
  if training_args.fsdp is not None and len(training_args.fsdp) > 0:
699
  if len(params_no_grad) < 10:
700
+ print(
701
+ "[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}".format(
702
+ len(params_no_grad), params_no_grad
703
+ )
704
+ )
705
  else:
706
+ print(
707
+ "[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)".format(
708
+ len(params_no_grad), ", ".join(params_no_grad[:10])
709
+ )
710
+ )
711
+ print(
712
+ "[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental."
713
+ )
714
+ print(
715
+ "[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining"
716
+ )
717
+
718
+ from torch.distributed.fsdp.fully_sharded_data_parallel import \
719
+ FullyShardedDataParallel as FSDP
720
 
 
721
  def patch_FSDP_use_orig_params(func):
722
  def wrap_func(*args, **kwargs):
723
+ use_orig_params = kwargs.pop("use_orig_params", True)
724
  return func(*args, **kwargs, use_orig_params=use_orig_params)
725
+
726
  return wrap_func
727
 
728
  FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)
729
 
730
+ data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
731
+ trainer = LLaVATrainer(
732
+ model=model, tokenizer=tokenizer, args=training_args, **data_module
733
+ )
 
 
734
 
735
  if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
736
  trainer.train(resume_from_checkpoint=True)
737
  else:
738
  trainer.train()
739
  trainer.save_state()
740
+ safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
 
741
 
742
 
743
  if __name__ == "__main__":
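
Note on the `pad` branch added to `LazySupervisedDataset.__getitem__` above: non-square images are padded to a square canvas filled with the image processor's per-channel mean before CLIP preprocessing. A standalone sketch of that padding step, assuming a Pillow image and using `[0.48, 0.46, 0.41]` as a stand-in for `processor.image_mean`:

```
from PIL import Image


def expand2square(pil_img: Image.Image, background_color) -> Image.Image:
    """Pad a PIL image to a square canvas, centering the original content."""
    width, height = pil_img.size
    if width == height:
        return pil_img
    if width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    result = Image.new(pil_img.mode, (height, height), background_color)
    result.paste(pil_img, ((height - width) // 2, 0))
    return result


# Example with an assumed CLIP-style image mean:
image_mean = [0.48, 0.46, 0.41]  # stand-in for processor.image_mean
img = Image.new("RGB", (640, 480), (255, 0, 0))  # dummy non-square image
padded = expand2square(img, tuple(int(x * 255) for x in image_mean))
assert padded.size == (640, 640)
```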
model/llava/train/train_mem.py CHANGED
@@ -3,7 +3,8 @@
3
  # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
4
 
5
  # Need to call this before importing transformers.
6
- from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
 
7
 
8
  replace_llama_attn_with_flash_attn()
9
 
 
3
  # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
4
 
5
  # Need to call this before importing transformers.
6
+ from llava.train.llama_flash_attn_monkey_patch import \
7
+ replace_llama_attn_with_flash_attn
8
 
9
  replace_llama_attn_with_flash_attn()
10
 
model/llava/utils.py CHANGED
@@ -5,11 +5,14 @@ import os
5
  import sys
6
 
7
  import requests
8
-
9
  from llava.constants import LOGDIR
10
 
11
- server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12
- moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
 
 
 
 
13
 
14
  handler = None
15
 
@@ -47,7 +50,8 @@ def build_logger(logger_name, logger_filename):
47
  os.makedirs(LOGDIR, exist_ok=True)
48
  filename = os.path.join(LOGDIR, logger_filename)
49
  handler = logging.handlers.TimedRotatingFileHandler(
50
- filename, when='D', utc=True)
 
51
  handler.setFormatter(formatter)
52
 
53
  for name, item in logging.root.manager.loggerDict.items():
@@ -61,33 +65,34 @@ class StreamToLogger(object):
61
  """
62
  Fake file-like stream object that redirects writes to a logger instance.
63
  """
 
64
  def __init__(self, logger, log_level=logging.INFO):
65
  self.terminal = sys.stdout
66
  self.logger = logger
67
  self.log_level = log_level
68
- self.linebuf = ''
69
 
70
  def __getattr__(self, attr):
71
  return getattr(self.terminal, attr)
72
 
73
  def write(self, buf):
74
  temp_linebuf = self.linebuf + buf
75
- self.linebuf = ''
76
  for line in temp_linebuf.splitlines(True):
77
  # From the io.TextIOWrapper docs:
78
  # On output, if newline is None, any '\n' characters written
79
  # are translated to the system default line separator.
80
  # By default sys.stdout.write() expects '\n' newlines and then
81
  # translates them so this is still cross platform.
82
- if line[-1] == '\n':
83
  self.logger.log(self.log_level, line.rstrip())
84
  else:
85
  self.linebuf += line
86
 
87
  def flush(self):
88
- if self.linebuf != '':
89
  self.logger.log(self.log_level, self.linebuf.rstrip())
90
- self.linebuf = ''
91
 
92
 
93
  def disable_torch_init():
@@ -95,6 +100,7 @@ def disable_torch_init():
95
  Disable the redundant torch default initialization to accelerate model creation.
96
  """
97
  import torch
 
98
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100
 
@@ -104,8 +110,10 @@ def violates_moderation(text):
104
  Check whether the text violates OpenAI moderation API.
105
  """
106
  url = "https://api.openai.com/v1/moderations"
107
- headers = {"Content-Type": "application/json",
108
- "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
 
 
109
  text = text.replace("\n", "")
110
  data = "{" + '"input": ' + f'"{text}"' + "}"
111
  data = data.encode("utf-8")
 
5
  import sys
6
 
7
  import requests
 
8
  from llava.constants import LOGDIR
9
 
10
+ server_error_msg = (
11
+ "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12
+ )
13
+ moderation_msg = (
14
+ "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
15
+ )
16
 
17
  handler = None
18
 
 
50
  os.makedirs(LOGDIR, exist_ok=True)
51
  filename = os.path.join(LOGDIR, logger_filename)
52
  handler = logging.handlers.TimedRotatingFileHandler(
53
+ filename, when="D", utc=True
54
+ )
55
  handler.setFormatter(formatter)
56
 
57
  for name, item in logging.root.manager.loggerDict.items():
 
65
  """
66
  Fake file-like stream object that redirects writes to a logger instance.
67
  """
68
+
69
  def __init__(self, logger, log_level=logging.INFO):
70
  self.terminal = sys.stdout
71
  self.logger = logger
72
  self.log_level = log_level
73
+ self.linebuf = ""
74
 
75
  def __getattr__(self, attr):
76
  return getattr(self.terminal, attr)
77
 
78
  def write(self, buf):
79
  temp_linebuf = self.linebuf + buf
80
+ self.linebuf = ""
81
  for line in temp_linebuf.splitlines(True):
82
  # From the io.TextIOWrapper docs:
83
  # On output, if newline is None, any '\n' characters written
84
  # are translated to the system default line separator.
85
  # By default sys.stdout.write() expects '\n' newlines and then
86
  # translates them so this is still cross platform.
87
+ if line[-1] == "\n":
88
  self.logger.log(self.log_level, line.rstrip())
89
  else:
90
  self.linebuf += line
91
 
92
  def flush(self):
93
+ if self.linebuf != "":
94
  self.logger.log(self.log_level, self.linebuf.rstrip())
95
+ self.linebuf = ""
96
 
97
 
98
  def disable_torch_init():
 
100
  Disable the redundant torch default initialization to accelerate model creation.
101
  """
102
  import torch
103
+
104
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
105
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
106
 
 
110
  Check whether the text violates OpenAI moderation API.
111
  """
112
  url = "https://api.openai.com/v1/moderations"
113
+ headers = {
114
+ "Content-Type": "application/json",
115
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"],
116
+ }
117
  text = text.replace("\n", "")
118
  data = "{" + '"input": ' + f'"{text}"' + "}"
119
  data = data.encode("utf-8")
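
A short usage sketch for the logging helpers reformatted above, assuming `build_logger` and `StreamToLogger` are imported from `llava.utils` as defined in this file; the logger name and filename are placeholders:

```
import logging
import sys

from llava.utils import StreamToLogger, build_logger

logger = build_logger("demo_worker", "demo_worker.log")

# Route anything written to stdout/stderr into the rotating log file as well.
sys.stdout = StreamToLogger(logger, logging.INFO)
sys.stderr = StreamToLogger(logger, logging.ERROR)

print("this line is captured by the logger instead of only the console")
```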
model/segment_anything/__init__.py CHANGED
@@ -4,12 +4,7 @@
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
- from .build_sam import (
8
- build_sam,
9
- build_sam_vit_h,
10
- build_sam_vit_l,
11
- build_sam_vit_b,
12
- sam_model_registry,
13
- )
14
- from .predictor import SamPredictor
15
  from .automatic_mask_generator import SamAutomaticMaskGenerator
 
 
 
 
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
 
 
 
 
 
 
 
 
7
  from .automatic_mask_generator import SamAutomaticMaskGenerator
8
+ from .build_sam import (build_sam, build_sam_vit_b, build_sam_vit_h,
9
+ build_sam_vit_l, sam_model_registry)
10
+ from .predictor import SamPredictor
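
The re-exports above are typically consumed as below; a hedged sketch with a placeholder checkpoint path and a dummy image (the SAM checkpoint itself is not part of this diff):

```
import numpy as np

from model.segment_anything import SamPredictor, sam_model_registry

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")  # placeholder path
predictor = SamPredictor(sam)

image = np.zeros((480, 640, 3), dtype=np.uint8)  # HWC, RGB
predictor.set_image(image)
masks, scores, logits = predictor.predict(
    point_coords=np.array([[320, 240]]),
    point_labels=np.array([1]),
)
```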
model/segment_anything/automatic_mask_generator.py CHANGED
@@ -4,32 +4,21 @@
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
 
 
7
  import numpy as np
8
  import torch
9
  from torchvision.ops.boxes import batched_nms, box_area # type: ignore
10
 
11
- from typing import Any, Dict, List, Optional, Tuple
12
-
13
  from .modeling import Sam
14
  from .predictor import SamPredictor
15
- from .utils.amg import (
16
- MaskData,
17
- area_from_rle,
18
- batch_iterator,
19
- batched_mask_to_box,
20
- box_xyxy_to_xywh,
21
- build_all_layer_point_grids,
22
- calculate_stability_score,
23
- coco_encode_rle,
24
- generate_crop_boxes,
25
- is_box_near_crop_edge,
26
- mask_to_rle_pytorch,
27
- remove_small_regions,
28
- rle_to_mask,
29
- uncrop_boxes_xyxy,
30
- uncrop_masks,
31
- uncrop_points,
32
- )
33
 
34
 
35
  class SamAutomaticMaskGenerator:
@@ -115,7 +104,8 @@ class SamAutomaticMaskGenerator:
115
  "coco_rle",
116
  ], f"Unknown output_mode {output_mode}."
117
  if output_mode == "coco_rle":
118
- from pycocotools import mask as mask_utils # type: ignore # noqa: F401
 
119
 
120
  if min_mask_region_area > 0:
121
  import cv2 # type: ignore # noqa: F401
@@ -172,7 +162,9 @@ class SamAutomaticMaskGenerator:
172
 
173
  # Encode masks
174
  if self.output_mode == "coco_rle":
175
- mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
 
 
176
  elif self.output_mode == "binary_mask":
177
  mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
178
  else:
@@ -242,7 +234,9 @@ class SamAutomaticMaskGenerator:
242
  # Generate masks for this crop in batches
243
  data = MaskData()
244
  for (points,) in batch_iterator(self.points_per_batch, points_for_image):
245
- batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
 
 
246
  data.cat(batch_data)
247
  del batch_data
248
  self.predictor.reset_image()
@@ -275,7 +269,9 @@ class SamAutomaticMaskGenerator:
275
  # Run model on this batch
276
  transformed_points = self.predictor.transform.apply_coords(points, im_size)
277
  in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
278
- in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
 
 
279
  masks, iou_preds, _ = self.predictor.predict_torch(
280
  in_points[:, None, :],
281
  in_labels[:, None],
@@ -298,7 +294,9 @@ class SamAutomaticMaskGenerator:
298
 
299
  # Calculate stability score
300
  data["stability_score"] = calculate_stability_score(
301
- data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
 
 
302
  )
303
  if self.stability_score_thresh > 0.0:
304
  keep_mask = data["stability_score"] >= self.stability_score_thresh
@@ -309,7 +307,9 @@ class SamAutomaticMaskGenerator:
309
  data["boxes"] = batched_mask_to_box(data["masks"])
310
 
311
  # Filter boxes that touch crop boundaries
312
- keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
 
 
313
  if not torch.all(keep_mask):
314
  data.filter(keep_mask)
315
 
 
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
  import numpy as np
10
  import torch
11
  from torchvision.ops.boxes import batched_nms, box_area # type: ignore
12
 
 
 
13
  from .modeling import Sam
14
  from .predictor import SamPredictor
15
+ from .utils.amg import (MaskData, area_from_rle, batch_iterator,
16
+ batched_mask_to_box, box_xyxy_to_xywh,
17
+ build_all_layer_point_grids, calculate_stability_score,
18
+ coco_encode_rle, generate_crop_boxes,
19
+ is_box_near_crop_edge, mask_to_rle_pytorch,
20
+ remove_small_regions, rle_to_mask, uncrop_boxes_xyxy,
21
+ uncrop_masks, uncrop_points)
 
 
 
 
 
 
 
 
 
 
 
22
 
23
 
24
  class SamAutomaticMaskGenerator:
 
104
  "coco_rle",
105
  ], f"Unknown output_mode {output_mode}."
106
  if output_mode == "coco_rle":
107
+ from pycocotools import \
108
+ mask as mask_utils # type: ignore # noqa: F401
109
 
110
  if min_mask_region_area > 0:
111
  import cv2 # type: ignore # noqa: F401
 
162
 
163
  # Encode masks
164
  if self.output_mode == "coco_rle":
165
+ mask_data["segmentations"] = [
166
+ coco_encode_rle(rle) for rle in mask_data["rles"]
167
+ ]
168
  elif self.output_mode == "binary_mask":
169
  mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
170
  else:
 
234
  # Generate masks for this crop in batches
235
  data = MaskData()
236
  for (points,) in batch_iterator(self.points_per_batch, points_for_image):
237
+ batch_data = self._process_batch(
238
+ points, cropped_im_size, crop_box, orig_size
239
+ )
240
  data.cat(batch_data)
241
  del batch_data
242
  self.predictor.reset_image()
 
269
  # Run model on this batch
270
  transformed_points = self.predictor.transform.apply_coords(points, im_size)
271
  in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
272
+ in_labels = torch.ones(
273
+ in_points.shape[0], dtype=torch.int, device=in_points.device
274
+ )
275
  masks, iou_preds, _ = self.predictor.predict_torch(
276
  in_points[:, None, :],
277
  in_labels[:, None],
 
294
 
295
  # Calculate stability score
296
  data["stability_score"] = calculate_stability_score(
297
+ data["masks"],
298
+ self.predictor.model.mask_threshold,
299
+ self.stability_score_offset,
300
  )
301
  if self.stability_score_thresh > 0.0:
302
  keep_mask = data["stability_score"] >= self.stability_score_thresh
 
307
  data["boxes"] = batched_mask_to_box(data["masks"])
308
 
309
  # Filter boxes that touch crop boundaries
310
+ keep_mask = ~is_box_near_crop_edge(
311
+ data["boxes"], crop_box, [0, 0, orig_w, orig_h]
312
+ )
313
  if not torch.all(keep_mask):
314
  data.filter(keep_mask)
315
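
The reformatted generator above is driven the same way; a hedged usage sketch, again with a placeholder checkpoint and dummy image:

```
import numpy as np

from model.segment_anything import SamAutomaticMaskGenerator, sam_model_registry

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")  # placeholder path
mask_generator = SamAutomaticMaskGenerator(
    sam,
    points_per_batch=64,
    output_mode="binary_mask",  # "coco_rle" triggers the pycocotools import above
)

image = np.zeros((480, 640, 3), dtype=np.uint8)  # HWC, RGB
masks = mask_generator.generate(image)
# Each entry is a dict with keys such as "segmentation", "area", "bbox",
# "predicted_iou" and "stability_score".
```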