jadechoghari committed on
Commit 442397e · verified · 1 Parent(s): 3dd8fd5

Create model_UI.py

Files changed (1)
  1. model_UI.py +273 -0
model_UI.py ADDED
@@ -0,0 +1,273 @@
+ import argparse
+ import json
+ import math
+ import os
+ from copy import deepcopy
+ from functools import partial
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from tqdm import tqdm
+
+ from conversation import conv_templates, SeparatorStyle
+ from builder import load_pretrained_model
+ from mm_utils import tokenizer_image_token, process_images
+
+ IMAGE_TOKEN_INDEX = -200
+ DEFAULT_IMAGE_TOKEN = "<image>"
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+ DEFAULT_IM_START_TOKEN = "<im_start>"
+ DEFAULT_IM_END_TOKEN = "<im_end>"
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
+
+ # Added by Ferret
+ DEFAULT_REGION_FEA_TOKEN = "<region_fea>"
+ VOCAB_IMAGE_W = 1000
+ VOCAB_IMAGE_H = 1000
+
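+ # Note: VOCAB_IMAGE_W/H define the fixed 1000 x 1000 coordinate space used when box coordinates
+ # are written into the question text (see get_bbox_coor and UIData.__getitem__ below); region
+ # masks, by contrast, stay in raw screenshot coordinates.
+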
+ def disable_torch_init():
+     """
+     Disable the redundant torch default initialization to accelerate model creation.
+     """
+     import torch
+     setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+ def split_list(lst, n):
+     """Split a list into n (roughly) equal-sized chunks."""
+     chunk_size = math.ceil(len(lst) / n)  # ceiling division so every element is covered
+     return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
+
+ def get_chunk(lst, n, k):
+     """Return the k-th of n chunks."""
+     chunks = split_list(lst, n)
+     return chunks[k]
+
+ def generate_mask_for_feature(coor, raw_w, raw_h, mask=None):
+     """Build a binary region mask of shape (raw_w, raw_h) from a point or box in raw image coordinates."""
+     if mask is not None:
+         assert mask.shape[0] == raw_w and mask.shape[1] == raw_h
+     coor_mask = np.zeros((raw_w, raw_h))
+     # Assume it samples a point.
+     if len(coor) == 2:
+         # Define window size
+         span = 5
+         # Make sure the window does not exceed array bounds
+         x_min = max(0, coor[0] - span)
+         x_max = min(raw_w, coor[0] + span + 1)
+         y_min = max(0, coor[1] - span)
+         y_max = min(raw_h, coor[1] + span + 1)
+         coor_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1
+         assert (coor_mask == 1).any(), f"coor: {coor}, raw_w: {raw_w}, raw_h: {raw_h}"
+     elif len(coor) == 4:
+         # Box input or sketch input.
+         coor_mask[coor[0]:coor[2]+1, coor[1]:coor[3]+1] = 1
+         if mask is not None:
+             coor_mask = coor_mask * mask
+     coor_mask = torch.from_numpy(coor_mask)
+     assert len(coor_mask.nonzero()) != 0, f"Empty region mask for coor: {coor}"
+     return coor_mask
+
+ def get_task_from_file(file):
+     box_in_tasks = ['widgetcaptions', 'taperception', 'ocr', 'icon_recognition', 'widget_classification', 'example_0']
+     # box_out_tasks = ['widget_listing', 'find_text', 'find_icons', 'find_widget', 'conversation_interaction']
+     # no_box = ['screen2words', 'detailed_description', 'conversation_perception', 'gpt4']
+     if any(task in file for task in box_in_tasks):
+         return 'box_in'
+     else:
+         return 'no_box_in'
+     # elif any(task in file for task in box_out_tasks):
+     #     return 'box_out'
+     # elif any(task in file for task in no_box):
+     #     return 'no_box'
+
+ def get_bbox_coor(box, ratio_w, ratio_h):
+     return box[0] * ratio_w, box[1] * ratio_h, box[2] * ratio_w, box[3] * ratio_h
+
+ def get_model_name_from_path(model_path):
+     if 'gemma' in model_path:
+         return 'ferret_gemma'
+     elif 'llama' in model_path or 'vicuna' in model_path:
+         return 'ferret_llama'
+     else:
+         raise ValueError(f"No model matched for {model_path}")
+
+ class UIData:
+     def __init__(self, data_path, image_path, args) -> None:
+         with open(data_path, 'r') as f:
+             self.obj_list = json.load(f)
+         self.image_path = image_path
+         self.args = args
+         self._ids = range(len(self.obj_list))
+         self.task = get_task_from_file(data_path)
+
+     @property
+     def ids(self):
+         return deepcopy(self._ids)
+
+     def __getitem__(self, idx):
+         i = self.obj_list[idx]
+
+         # Load the screenshot referenced by the annotation.
+         image_path_i = os.path.join(self.image_path, i['image'].split('/')[-1])
+         image = Image.open(image_path_i).convert('RGB')
+
+         q_turn = i['conversations'][0]['value']
+         if "<image>" in q_turn:
+             prompt = q_turn.split('\n')[1]
+         else:
+             prompt = q_turn
+         i['question'] = prompt
+         i['region_masks'] = None
+
+         if self.task == 'box_in':
+             # Rescale the raw box into the text-vocabulary coordinate space.
+             ratio_w = VOCAB_IMAGE_W * 1.0 / i['image_w']
+             ratio_h = VOCAB_IMAGE_H * 1.0 / i['image_h']
+
+             box = i['box_x1y1x2y2'][0][0]
+             box_x1, box_y1, box_x2, box_y2 = box
+             box_x1_textvocab, box_y1_textvocab, box_x2_textvocab, box_y2_textvocab = get_bbox_coor(box=box, ratio_h=ratio_h, ratio_w=ratio_w)
+
+             if self.args.region_format == 'box':
+                 region_coordinate_raw = [box_x1, box_y1, box_x2, box_y2]
+                 if self.args.add_region_feature:
+                     i['question'] = prompt.replace('<bbox_location0>', '[{}, {}, {}, {}] {}'.format(int(box_x1_textvocab), int(box_y1_textvocab), int(box_x2_textvocab), int(box_y2_textvocab), DEFAULT_REGION_FEA_TOKEN))
+                     generated_mask = generate_mask_for_feature(region_coordinate_raw, raw_w=i['image_w'], raw_h=i['image_h'], mask=None)
+                     i['region_masks'] = [generated_mask]
+                 else:
+                     i['question'] = prompt.replace('<bbox_location0>', '[{}, {}, {}, {}]'.format(int(box_x1_textvocab), int(box_y1_textvocab), int(box_x2_textvocab), int(box_y2_textvocab)))
+             else:
+                 raise NotImplementedError(f'{self.args.region_format} is not supported.')
+
+         return image, i, image.size
+
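+ # Sketch of the annotation record this loader expects, inferred from the field accesses in
+ # UIData.__getitem__ above and eval_model below (an assumption, not a documented schema):
+ #   {
+ #     "id": ...,
+ #     "image": ".../screenshot.png",
+ #     "image_w": ..., "image_h": ...,
+ #     "conversations": [{"value": "<image>\n... <bbox_location0> ..."}, {"value": "reference answer"}],
+ #     "box_x1y1x2y2": [[[x1, y1, x2, y2]]],   # box-in tasks only
+ #     "label": "optional ground-truth label"
+ #   }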
+ def eval_model(args):
+     # Data
+     dataset = UIData(data_path=args.data_path, image_path=args.image_path, args=args)
+     data_ids = dataset.ids
+
+     # Model
+     disable_torch_init()
+     model_path = os.path.expanduser(args.model_path)
+     model_name = get_model_name_from_path(model_path)
+     tokenizer, model, image_processor, context_len = \
+         load_pretrained_model(model_path, args.model_base, model_name)
+
+     chunk_data_ids = get_chunk(data_ids, args.num_chunks, args.chunk_idx)
+     answers_folder = os.path.expanduser(args.answers_file)
+     os.makedirs(answers_folder, exist_ok=True)
+     answers_file = os.path.join(answers_folder, f'{args.chunk_idx}_of_{args.num_chunks}.jsonl')
+     ans_file = open(answers_file, "w")
+
+     for data_id in tqdm(chunk_data_ids):
+         img, ann, image_size = dataset[data_id]
+         image_path = ann['image']
+         qs = ann["question"]
+         cur_prompt = qs
+
+         if "<image>" in qs:
+             qs = qs.split('\n')[1]
+
+         if model.config.mm_use_im_start_end:
+             qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
+         else:
+             qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
+
+         conv = conv_templates[args.conv_mode].copy()
+         conv.append_message(conv.roles[0], qs)
+         conv.append_message(conv.roles[1], None)
+         prompt = conv.get_prompt()
+
+         input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
+
+         if model.config.image_aspect_ratio == "square_nocrop":
+             image_tensor = image_processor.preprocess(img, return_tensors='pt', do_resize=True,
+                                                       do_center_crop=False, size=[args.image_h, args.image_w])['pixel_values'][0]
+         elif model.config.image_aspect_ratio == "anyres":
+             image_process_func = partial(image_processor.preprocess, return_tensors='pt', do_resize=True, do_center_crop=False, size=[args.image_h, args.image_w])
+             image_tensor = process_images([img], image_processor, model.config, image_process_func=image_process_func)[0]
+         else:
+             image_tensor = process_images([img], image_processor, model.config)[0]
+
+         images = image_tensor.unsqueeze(0).to(args.data_type).cuda()
+
+         region_masks = ann['region_masks']
+
+         if region_masks is not None:
+             region_masks = [[region_mask_i.cuda().half() for region_mask_i in region_masks]]
+         else:
+             region_masks = None
+
+         with torch.inference_mode():
+             # Temporarily patch forward so generate() passes region_masks through to the model.
+             model.orig_forward = model.forward
+             model.forward = partial(
+                 model.orig_forward,
+                 region_masks=region_masks
+             )
+             output_ids = model.generate(
+                 input_ids,
+                 images=images,
+                 region_masks=region_masks,
+                 image_sizes=[image_size],
+                 do_sample=True if args.temperature > 0 else False,
+                 temperature=args.temperature,
+                 top_p=args.top_p,
+                 num_beams=args.num_beams,
+                 max_new_tokens=args.max_new_tokens,
+                 use_cache=True)
+             model.forward = model.orig_forward
+
+         outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
+         outputs = outputs.strip()
+
+         if 'label' in ann:
+             label = ann['label']
+         elif len(ann['conversations']) > 1:
+             label = ann['conversations'][1]['value']
+         else:
+             label = None
+
+         ans_file.write(json.dumps({"id": ann['id'],  # +1 offset
+                                    "image_path": image_path,
+                                    "prompt": cur_prompt,
+                                    "text": outputs,
+                                    "label": label,
+                                    }) + "\n")
+         ans_file.flush()
+     ans_file.close()
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model_path", type=str, default="facebook/opt-350m")
+     parser.add_argument("--vision_model_path", type=str, default=None)
+     parser.add_argument("--model_base", type=str, default=None)
+     parser.add_argument("--image_path", type=str, default="")
+     parser.add_argument("--data_path", type=str, default="")
+     parser.add_argument("--answers_file", type=str, default="")
+     parser.add_argument("--conv_mode", type=str, default="ferret_gemma_instruct",
+                         help="[ferret_gemma_instruct,ferret_llama_3,ferret_vicuna_v1]")
+     parser.add_argument("--num_chunks", type=int, default=1)
+     parser.add_argument("--chunk_idx", type=int, default=0)
+     parser.add_argument("--image_w", type=int, default=336)  # 224
+     parser.add_argument("--image_h", type=int, default=336)  # 224
+     parser.add_argument("--add_region_feature", action="store_true")
+     parser.add_argument("--region_format", type=str, default="point", choices=["point", "box", "segment", "free_shape"])
+     parser.add_argument("--no_coor", action="store_true")
+     parser.add_argument("--temperature", type=float, default=0.001)
+     parser.add_argument("--top_p", type=float, default=None)
+     parser.add_argument("--num_beams", type=int, default=1)
+     parser.add_argument("--max_new_tokens", type=int, default=1024)
+     parser.add_argument("--data_type", type=str, default='fp16', choices=['fp16', 'bf16', 'fp32'])
+     args = parser.parse_args()
+
+     if args.data_type == 'fp16':
+         args.data_type = torch.float16
+     elif args.data_type == 'bf16':
+         args.data_type = torch.bfloat16
+     else:
+         args.data_type = torch.float32
+
+ eval_model(args)
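+
+ # Example invocation (a sketch; the file and directory names are placeholders, only the flags
+ # come from the argparse definition above):
+ #   python model_UI.py --model_path <ferret_gemma_checkpoint> \
+ #       --data_path <eval_json, e.g. one containing "widgetcaptions" in its name> \
+ #       --image_path <screenshot_dir> --answers_file <output_dir> \
+ #       --conv_mode ferret_gemma_instruct --region_format box --add_region_feature --data_type fp16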