jadechoghari committed on
Commit 8f4b3a0
1 Parent(s): 6e3e9da

Update inference.py

Files changed (1)
  1. inference.py +217 -22
inference.py CHANGED
@@ -1,19 +1,200 @@
  import torch
  from PIL import Image
  from conversation import conv_templates
- from builder import load_pretrained_model # Assuming this is your custom model loader
  from functools import partial
  import numpy as np
  DEFAULT_REGION_FEA_TOKEN = "<region_fea>"
  DEFAULT_IMAGE_TOKEN = "<image>"
  DEFAULT_IM_START_TOKEN = "<im_start>"
  DEFAULT_IM_END_TOKEN = "<im_end>"

  # define the task categories
  box_in_tasks = ['widgetcaptions', 'taperception', 'ocr', 'icon_recognition', 'widget_classification', 'example_0']
  box_out_tasks = ['widget_listing', 'find_text', 'find_icons', 'find_widget', 'conversation_interaction']
  no_box_tasks = ['screen2words', 'detailed_description', 'conversation_perception', 'gpt4']
  # function to generate the mask
  def generate_mask_for_feature(coor, raw_w, raw_h, mask=None):
      """
@@ -56,42 +237,58 @@ def infer_single_prompt(image_path, prompt, model_path, region=None, model_name=
      # define the image size required by clip
      image_size = {"height": 336, "width": 336}

-     # process the image
-     image_tensor = image_processor.preprocess(
-         img,
-         return_tensors='pt',
-         do_resize=True,
-         do_center_crop=False,
-         size=(image_size['height'], image_size['width'])
-     )['pixel_values'][0].unsqueeze(0)

-     image_tensor = image_tensor.half().cuda()

      # generate the prompt per template requirement
      conv = conv_templates[conv_mode].copy()
      conv.append_message(conv.roles[0], prompt)
      conv.append_message(conv.roles[1], None)
      prompt_input = conv.get_prompt()
-
-     # add the special tokens
-     prompt_input = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + prompt_input

      # region mask logic (if region is provided)
      region_masks = None
      if add_region_feature and region is not None:
          raw_w, raw_h = img.size
-         region_masks = generate_mask_for_feature(region, raw_w, raw_h).unsqueeze(0).cuda().half()
          region_masks = [[region_mask_i.cuda().half() for region_mask_i in region_masks]]
-         prompt_input = prompt_input.replace("<bbox_location0>", f"[{region[0]}, {region[1]}, {region[2]}, {region[3]}] {DEFAULT_REGION_FEA_TOKEN}")

      # tokenize prompt
      # input_ids = tokenizer(prompt_input, return_tensors='pt')['input_ids'].cuda()

-     inputs = tokenizer(prompt_input, return_tensors='pt', padding=True)
-     input_ids = inputs['input_ids'].cuda()
-     attention_mask = inputs['attention_mask'].cuda()

      # generate model output
      with torch.inference_mode():
@@ -104,8 +301,7 @@ def infer_single_prompt(image_path, prompt, model_path, region=None, model_name=
          # explicit add of attention mask
          output_ids = model.generate(
              input_ids,
-             images=image_tensor,
-             attention_mask=attention_mask,
              max_new_tokens=1024,
              num_beams=1,
              region_masks=region_masks,  # pass the region mask to the model
@@ -119,7 +315,6 @@ def infer_single_prompt(image_path, prompt, model_path, region=None, model_name=

  # We also define a task-specific inference function
  def infer_ui_task(image_path, prompt, model_path, task, region=None, add_region_feature=False):
-     # region = torch.tensor(region).cuda()
      """
      Handles task types: box_in_tasks, box_out_tasks, no_box_tasks.
      """
@@ -141,4 +336,4 @@ def infer_ui_task(image_path, prompt, model_path, task, region=None, add_region_
          return infer_single_prompt(image_path, prompt, model_path)

      else:
-         raise ValueError(f"Unknown task type: {task}")

  import torch
  from PIL import Image
  from conversation import conv_templates
+ from builder import load_pretrained_model
  from functools import partial
+ from typing import Optional, Callable
+ import ast
+ import math
  import numpy as np
  DEFAULT_REGION_FEA_TOKEN = "<region_fea>"
  DEFAULT_IMAGE_TOKEN = "<image>"
  DEFAULT_IM_START_TOKEN = "<im_start>"
  DEFAULT_IM_END_TOKEN = "<im_end>"
+ VOCAB_IMAGE_W = 1000  # 224
+ VOCAB_IMAGE_H = 1000  # 224
+ IMAGE_TOKEN_INDEX = -200
+

  # define the task categories
  box_in_tasks = ['widgetcaptions', 'taperception', 'ocr', 'icon_recognition', 'widget_classification', 'example_0']
  box_out_tasks = ['widget_listing', 'find_text', 'find_icons', 'find_widget', 'conversation_interaction']
  no_box_tasks = ['screen2words', 'detailed_description', 'conversation_perception', 'gpt4']

+ def get_bbox_coor(box, ratio_w, ratio_h):
+     return box[0] * ratio_w, box[1] * ratio_h, box[2] * ratio_w, box[3] * ratio_h
+
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+     if '<image>' in prompt:
+         prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
+         input_ids = []
+         for i, chunk in enumerate(prompt_chunks):
+             input_ids.extend(chunk)
+             if i < len(prompt_chunks) - 1:
+                 input_ids.append(image_token_index)
+     else:
+         input_ids = tokenizer(prompt).input_ids
+     # if return_tensors == 'pt':
+     #     import torch
+     #     input_ids = torch.tensor(input_ids).unsqueeze(0)
+
+     return input_ids
+
+
+ def expand2square(pil_img, background_color):
+     width, height = pil_img.size
+     if width == height:
+         return pil_img
+     elif width > height:
+         result = Image.new(pil_img.mode, (width, width), background_color)
+         result.paste(pil_img, (0, (width - height) // 2))
+         return result
+     else:
+         result = Image.new(pil_img.mode, (height, height), background_color)
+         result.paste(pil_img, ((height - width) // 2, 0))
+         return result
+
+ def select_best_resolution(original_size, possible_resolutions):
+     """
+     Selects the best resolution from a list of possible resolutions based on the original size.
+
+     Args:
+         original_size (tuple): The original size of the image in the format (width, height).
+         possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
+
+     Returns:
+         tuple: The best fit resolution in the format (width, height).
+     """
+     original_width, original_height = original_size
+     best_fit = None
+     max_effective_resolution = 0
+     min_wasted_resolution = float('inf')
+
+     for width, height in possible_resolutions:
+         scale = min(width / original_width, height / original_height)
+         downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
+         effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
+         wasted_resolution = (width * height) - effective_resolution
+
+         if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
+             max_effective_resolution = effective_resolution
+             min_wasted_resolution = wasted_resolution
+             best_fit = (width, height)
+
+     return best_fit
+
+ def divide_to_patches(image, patch_size):
+     """
+     Divides an image into patches of a specified size.
+
+     Args:
+         image (PIL.Image.Image): The input image.
+         patch_size (int): The size of each patch.
+
+     Returns:
+         list: A list of PIL.Image.Image objects representing the patches.
+     """
+     patches = []
+     width, height = image.size
+     for i in range(0, height, patch_size):
+         for j in range(0, width, patch_size):
+             box = (j, i, j + patch_size, i + patch_size)
+             patch = image.crop(box)
+             patches.append(patch)
+
+     return patches
+ def resize_and_pad_image(image, target_resolution, is_pad=False):
+     """
+     Resize and pad an image to a target resolution while maintaining aspect ratio.
+     Args:
+         image (PIL.Image.Image): The input image.
+         target_resolution (tuple): The target resolution (width, height) of the image.
+     Returns:
+         PIL.Image.Image: The resized and padded image.
+     """
+     original_width, original_height = image.size
+     target_width, target_height = target_resolution
+
+     if is_pad:
+         scale_w = target_width / original_width
+         scale_h = target_height / original_height
+
+         if scale_w < scale_h:
+             new_width = target_width
+             new_height = min(math.ceil(original_height * scale_w), target_height)
+         else:
+             new_height = target_height
+             new_width = min(math.ceil(original_width * scale_h), target_width)
+
+         # Resize the image
+         resized_image = image.resize((new_width, new_height))
+
+         new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
+         paste_x = (target_width - new_width) // 2
+         paste_y = (target_height - new_height) // 2
+         new_image.paste(resized_image, (paste_x, paste_y))
+     else:
+         new_image = image.resize((target_width, target_height))
+
+     return new_image
+
+ def process_anyres_image(image, processor, grid_pinpoints, image_process_func: Optional[Callable] = None):
+     """
+     Process an image with variable resolutions.
+
+     Args:
+         image (PIL.Image.Image): The input image to be processed.
+         processor: The image processor object.
+         grid_pinpoints (str): A string representation of a list of possible resolutions.
+
+     Returns:
+         torch.Tensor: A tensor containing the processed image patches.
+     """
+     if type(grid_pinpoints) is list:
+         possible_resolutions = grid_pinpoints
+     else:
+         possible_resolutions = ast.literal_eval(grid_pinpoints)
+
+     best_resolution = select_best_resolution(image.size, possible_resolutions)
+
+     # FIXME: not sure if do_pad or undo_pad may affect the referring side
+     image_padded = resize_and_pad_image(image, best_resolution, is_pad=False)
+
+     patches = divide_to_patches(image_padded, processor.crop_size['height'])
+
+     if image_process_func:
+         resized_image_h, resized_image_w = image_process_func.keywords['size']
+         image_original_resize = image.resize((resized_image_w, resized_image_h))
+         image_patches = [image_original_resize] + patches
+         image_patches = [image_process_func(image_patch)['pixel_values'][0]
+                          for image_patch in image_patches]
+     else:
+         image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
+         image_patches = [image_original_resize] + patches
+         image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
+                          for image_patch in image_patches]
+
+     return torch.stack(image_patches, dim=0)
+
+
+ def process_images(images, image_processor, model_cfg, image_process_func: Optional[Callable] = None):
+     image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
+     new_images = []
+     if image_aspect_ratio == 'pad':
+         for image in images:
+             image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+             image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+             new_images.append(image)
+     elif image_aspect_ratio == "anyres":
+         # image_processor(images, return_tensors='pt', do_resize=True, do_center_crop=False, size=[image_h, image_w])['pixel_values']
+         for image in images:
+             image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints, image_process_func=image_process_func)
+             new_images.append(image)
+     else:
+         return image_processor(images, return_tensors='pt')['pixel_values']
+     if all(x.shape == new_images[0].shape for x in new_images):
+         new_images = torch.stack(new_images, dim=0)
+     return new_images
  # function to generate the mask
  def generate_mask_for_feature(coor, raw_w, raw_h, mask=None):
      """

      # define the image size required by clip
      image_size = {"height": 336, "width": 336}

+     if "<image>" in prompt:
+         prompt = prompt.split('\n')[1]

+     if model.config.mm_use_im_start_end:
+         prompt = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + prompt
+     else:
+         prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt

      # generate the prompt per template requirement
      conv = conv_templates[conv_mode].copy()
      conv.append_message(conv.roles[0], prompt)
      conv.append_message(conv.roles[1], None)
      prompt_input = conv.get_prompt()
+
+     input_ids = tokenizer(prompt_input, return_tensors='pt')['input_ids'].cuda()
+
+     # raw_w, raw_h = img.size  # check whether this shouldn't be width and height
+     raw_w = image_size["width"]
+     raw_h = image_size["height"]
+     if model.config.image_aspect_ratio == "square_nocrop":
+         image_tensor = image_processor.preprocess(img, return_tensors='pt', do_resize=True,
+                                                   do_center_crop=False, size=[raw_h, raw_w])['pixel_values'][0]
+     elif model.config.image_aspect_ratio == "anyres":
+         image_process_func = partial(image_processor.preprocess, return_tensors='pt', do_resize=True, do_center_crop=False, size=[raw_h, raw_h])
+         image_tensor = process_images([img], image_processor, model.config, image_process_func=image_process_func)[0]
+     else:
+         image_tensor = process_images([img], image_processor, model.config)[0]
+
+     images = image_tensor.unsqueeze(0).to(torch.float16).cuda()


      # region mask logic (if region is provided)
      region_masks = None
      if add_region_feature and region is not None:
+         # box_in is true
          raw_w, raw_h = img.size
+         ratio_w = VOCAB_IMAGE_W * 1.0 / raw_w
+         ratio_h = VOCAB_IMAGE_H * 1.0 / raw_h
+         # preprocess the region
+         box_x1, box_y1, box_x2, box_y2 = region
+         box_x1_textvocab, box_y1_textvocab, box_x2_textvocab, box_y2_textvocab = get_bbox_coor(box=region, ratio_h=ratio_h, ratio_w=ratio_w)
+         region_coordinate_raw = [box_x1, box_y1, box_x2, box_y2]
+
+         region_masks = generate_mask_for_feature(region_coordinate_raw, raw_w, raw_h).unsqueeze(0).cuda().half()
          region_masks = [[region_mask_i.cuda().half() for region_mask_i in region_masks]]
+         prompt_input = prompt_input.replace("<bbox_location0>", f"[{box_x1_textvocab}, {box_y1_textvocab}, {box_x2_textvocab}, {box_y2_textvocab}] {DEFAULT_REGION_FEA_TOKEN}")

      # tokenize prompt
      # input_ids = tokenizer(prompt_input, return_tensors='pt')['input_ids'].cuda()

+

      # generate model output
      with torch.inference_mode():

          # explicit add of attention mask
          output_ids = model.generate(
              input_ids,
+             images=images,
              max_new_tokens=1024,
              num_beams=1,
              region_masks=region_masks,  # pass the region mask to the model


  # We also define a task-specific inference function
  def infer_ui_task(image_path, prompt, model_path, task, region=None, add_region_feature=False):
      """
      Handles task types: box_in_tasks, box_out_tasks, no_box_tasks.
      """

          return infer_single_prompt(image_path, prompt, model_path)

      else:
+         raise ValueError(f"Unknown task type: {task}")
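
For reference, a minimal usage sketch of the updated entry point. This is not part of the commit: it assumes the file is importable as inference, that a Ferret-UI-style checkpoint is available locally, and it uses placeholder image/checkpoint paths, a placeholder prompt, and a placeholder region. The region is given in raw image pixels and is rescaled internally to the 0-1000 text vocabulary via get_bbox_coor.

# Illustrative sketch only -- paths, prompt, and region below are placeholders.
from inference import infer_ui_task

# Box-in task: region is (x1, y1, x2, y2) in raw image pixels. Inside
# infer_single_prompt, get_bbox_coor rescales it to the 0-1000 text vocab,
# e.g. x1=40 on a 640-px-wide screenshot becomes 40 * 1000 / 640 = 62.5.
caption = infer_ui_task(
    image_path="screenshot.png",                   # placeholder image path
    prompt="Describe the widget at <bbox_location0>.",
    model_path="path/to/ferret-ui-checkpoint",     # placeholder checkpoint path
    task="widgetcaptions",
    region=(40, 120, 300, 180),
    add_region_feature=True,
)
print(caption)

# No-box task: no region is needed; the whole screen is summarized.
summary = infer_ui_task(
    image_path="screenshot.png",
    prompt="Summarize the screen.",
    model_path="path/to/ferret-ui-checkpoint",
    task="screen2words",
)
print(summary)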