Xenova HF staff commited on
Commit
a8eefc9
·
verified ·
1 Parent(s): f0c1b14

Delete processing_florence2.py

Browse files
Files changed (1) hide show
  1. processing_florence2.py +0 -1088
processing_florence2.py DELETED
@@ -1,1088 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 Microsoft and The HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """
16
- Processor class for Florence-2.
17
- """
18
-
19
- import re
20
- import logging
21
- from typing import List, Optional, Union
22
- import numpy as np
23
-
24
- import torch
25
-
26
- from transformers.feature_extraction_utils import BatchFeature
27
- from transformers.image_utils import ImageInput, is_valid_image
28
- from transformers.processing_utils import ProcessorMixin
29
- from transformers.tokenization_utils_base import (
30
- PaddingStrategy,
31
- PreTokenizedInput,
32
- TextInput,
33
- TruncationStrategy,
34
- )
35
- from transformers.utils import TensorType
36
-
37
-
38
- logger = logging.getLogger(__name__)
39
-
40
- # Copied from transformers.models.idefics2.processing_idefics2.is_url
41
- def is_url(val) -> bool:
42
- return isinstance(val, str) and val.startswith("http")
43
-
44
- # Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
45
- def is_image_or_image_url(elem):
46
- return is_url(elem) or is_valid_image(elem)
47
-
48
-
49
- def _is_str_or_image(elem):
50
- return isinstance(elem, (str)) or is_image_or_image_url(elem)
51
-
52
-
53
- class Florence2Processor(ProcessorMixin):
54
- r"""
55
- Constructs a Florence2 processor which wraps a Florence2 image processor and a Florence2 tokenizer into a single processor.
56
-
57
- [`Florence2Processor`] offers all the functionalities of [`CLIPImageProcessor`] and [`BartTokenizerFast`]. See the
58
- [`~Florence2Processor.__call__`] and [`~Florence2Processor.decode`] for more information.
59
-
60
- Args:
61
- image_processor ([`CLIPImageProcessor`], *optional*):
62
- The image processor is a required input.
63
- tokenizer ([`BartTokenizerFast`], *optional*):
64
- The tokenizer is a required input.
65
- """
66
-
67
- attributes = ["image_processor", "tokenizer"]
68
- image_processor_class = "CLIPImageProcessor"
69
- tokenizer_class = ("BartTokenizer", "BartTokenizerFast")
70
-
71
- def __init__(
72
- self,
73
- image_processor=None,
74
- tokenizer=None,
75
- ):
76
- if image_processor is None:
77
- raise ValueError("You need to specify an `image_processor`.")
78
- if tokenizer is None:
79
- raise ValueError("You need to specify a `tokenizer`.")
80
- if not hasattr(image_processor, "image_seq_length"):
81
- raise ValueError("Image processor is missing an `image_seq_length` attribute.")
82
-
83
- self.image_seq_length = image_processor.image_seq_length
84
-
85
- tokens_to_add = {
86
- 'additional_special_tokens': \
87
- tokenizer.additional_special_tokens + \
88
- ['<od>', '</od>', '<ocr>', '</ocr>'] + \
89
- [f'<loc_{x}>' for x in range(1000)] + \
90
- ['<cap>', '</cap>', '<ncap>', '</ncap>','<dcap>', '</dcap>', '<grounding>', '</grounding>', '<seg>', '</seg>', '<sep>', '<region_cap>', '</region_cap>', '<region_to_desciption>', '</region_to_desciption>', '<proposal>', '</proposal>', '<poly>', '</poly>', '<and>']
91
- }
92
- tokenizer.add_special_tokens(tokens_to_add)
93
-
94
- self.tasks_answer_post_processing_type = {
95
- '<OCR>': 'pure_text',
96
- '<OCR_WITH_REGION>': 'ocr',
97
- '<CAPTION>': 'pure_text',
98
- '<DETAILED_CAPTION>': 'pure_text',
99
- '<MORE_DETAILED_CAPTION>': 'pure_text',
100
- '<OD>': 'description_with_bboxes',
101
- '<DENSE_REGION_CAPTION>': 'description_with_bboxes',
102
- '<CAPTION_TO_PHRASE_GROUNDING>': "phrase_grounding",
103
- '<REFERRING_EXPRESSION_SEGMENTATION>': 'polygons',
104
- '<REGION_TO_SEGMENTATION>': 'polygons',
105
- '<OPEN_VOCABULARY_DETECTION>': 'description_with_bboxes_or_polygons',
106
- '<REGION_TO_CATEGORY>': 'pure_text',
107
- '<REGION_TO_DESCRIPTION>': 'pure_text',
108
- '<REGION_TO_OCR>': 'pure_text',
109
- '<REGION_PROPOSAL>': 'bboxes'
110
- }
111
-
112
- self.task_prompts_without_inputs = {
113
- '<OCR>': 'What is the text in the image?',
114
- '<OCR_WITH_REGION>': 'What is the text in the image, with regions?',
115
- '<CAPTION>': 'What does the image describe?',
116
- '<DETAILED_CAPTION>': 'Describe in detail what is shown in the image.',
117
- '<MORE_DETAILED_CAPTION>': 'Describe with a paragraph what is shown in the image.',
118
- '<OD>': 'Locate the objects with category name in the image.',
119
- '<DENSE_REGION_CAPTION>': 'Locate the objects in the image, with their descriptions.',
120
- '<REGION_PROPOSAL>': 'Locate the region proposals in the image.'
121
- }
122
-
123
- self.task_prompts_with_input = {
124
- '<CAPTION_TO_PHRASE_GROUNDING>': "Locate the phrases in the caption: {input}",
125
- '<REFERRING_EXPRESSION_SEGMENTATION>': 'Locate {input} in the image with mask',
126
- '<REGION_TO_SEGMENTATION>': 'What is the polygon mask of region {input}',
127
- '<OPEN_VOCABULARY_DETECTION>': 'Locate {input} in the image.',
128
- '<REGION_TO_CATEGORY>': 'What is the region {input}?',
129
- '<REGION_TO_DESCRIPTION>': 'What does the region {input} describe?',
130
- '<REGION_TO_OCR>': 'What text is in the region {input}?',
131
- }
132
-
133
- self.post_processor = Florence2PostProcesser(tokenizer=tokenizer)
134
-
135
-
136
- super().__init__(image_processor, tokenizer)
137
-
138
- def _construct_prompts(self, text):
139
- # replace the task tokens with the task prompts if task token is in the text
140
- prompts = []
141
- for _text in text:
142
- # 1. fixed task prompts without additional inputs
143
- for task_token, task_prompt in self.task_prompts_without_inputs.items():
144
- if task_token in _text:
145
- assert _text == task_token, f"Task token {task_token} should be the only token in the text."
146
- _text = task_prompt
147
- break
148
- # 2. task prompts with additional inputs
149
- for task_token, task_prompt in self.task_prompts_with_input.items():
150
- if task_token in _text:
151
- _text = task_prompt.format(input=_text.replace(task_token, ''))
152
- break
153
- prompts.append(_text)
154
- return prompts
155
-
156
- def __call__(
157
- self,
158
- text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
159
- images: ImageInput = None,
160
- tokenize_newline_separately: bool = True,
161
- padding: Union[bool, str, PaddingStrategy] = False,
162
- truncation: Union[bool, str, TruncationStrategy] = None,
163
- max_length=None,
164
- return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
165
- do_resize: bool = None,
166
- do_normalize: bool = None,
167
- image_mean: Optional[Union[float, List[float]]] = None,
168
- image_std: Optional[Union[float, List[float]]] = None,
169
- data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821
170
- input_data_format: Optional[
171
- Union[str, "ChannelDimension"] # noqa: F821
172
- ] = None,
173
- resample: "PILImageResampling" = None, # noqa: F821
174
- do_convert_rgb: bool = None,
175
- do_thumbnail: bool = None,
176
- do_align_long_axis: bool = None,
177
- do_rescale: bool = None,
178
- ) -> BatchFeature:
179
- """
180
- Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
181
- and `kwargs` arguments to BartTokenizerFast's [`~BartTokenizerFast.__call__`] if `text` is not `None` to encode
182
- the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
183
- CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
184
- of the above two methods for more information.
185
-
186
- Args:
187
- text (`str`, `List[str]`, `List[List[str]]`):
188
- The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
189
- (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
190
- `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
191
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
192
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
193
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
194
- number of channels, H and W are image height and width.
195
- tokenize_newline_separately (`bool`, defaults to `True`):
196
- Adds a separately tokenized '\n' at the end of the prompt.
197
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
198
- Select a strategy to pad the returned sequences (according to the model's padding side and padding
199
- index) among:
200
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
201
- sequence if provided).
202
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
203
- acceptable input length for the model if that argument is not provided.
204
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
205
- lengths).
206
- max_length (`int`, *optional*):
207
- Maximum length of the returned list and optionally padding length (see above).
208
- truncation (`bool`, *optional*):
209
- Activates truncation to cut input sequences longer than `max_length` to `max_length`.
210
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
211
- If set, will return tensors of a particular framework. Acceptable values are:
212
-
213
- - `'tf'`: Return TensorFlow `tf.constant` objects.
214
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
215
- - `'np'`: Return NumPy `np.ndarray` objects.
216
- - `'jax'`: Return JAX `jnp.ndarray` objects.
217
-
218
- Returns:
219
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
220
-
221
- - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. If `suffix`
222
- is provided, the `input_ids` will also contain the suffix input ids.
223
- - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
224
- `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
225
- `None`).
226
- - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
227
- - **labels** -- Labels compatible with training if `suffix` is not None
228
- """
229
-
230
- return_token_type_ids = False
231
-
232
- if images is None:
233
- raise ValueError("`images` are expected as arguments to a `Florence2Processor` instance.")
234
- if text is None:
235
- logger.warning_once(
236
- "You are using Florence-2 without a text prompt."
237
- )
238
- text = ""
239
-
240
- if isinstance(text, List) and isinstance(images, List):
241
- if len(images) < len(text):
242
- raise ValueError(
243
- f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image."
244
- )
245
- if _is_str_or_image(text):
246
- text = [text]
247
- elif isinstance(text, list) and _is_str_or_image(text[0]):
248
- pass
249
-
250
- pixel_values = self.image_processor(
251
- images,
252
- do_resize=do_resize,
253
- do_normalize=do_normalize,
254
- return_tensors=return_tensors,
255
- image_mean=image_mean,
256
- image_std=image_std,
257
- input_data_format=input_data_format,
258
- data_format=data_format,
259
- resample=resample,
260
- do_convert_rgb=do_convert_rgb,
261
- )["pixel_values"]
262
-
263
- if max_length is not None:
264
- max_length -= self.image_seq_length # max_length has to account for the image tokens
265
-
266
- text = self._construct_prompts(text)
267
-
268
- inputs = self.tokenizer(
269
- text,
270
- return_tensors=return_tensors,
271
- padding=padding,
272
- max_length=max_length,
273
- truncation=truncation,
274
- return_token_type_ids=return_token_type_ids,
275
- )
276
-
277
- return_data = {**inputs, "pixel_values": pixel_values}
278
-
279
- if return_token_type_ids:
280
- labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
281
- return_data.update({"labels": labels})
282
- return BatchFeature(data=return_data)
283
-
284
- # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Florence2
285
- def batch_decode(self, *args, **kwargs):
286
- """
287
- This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
288
- refer to the docstring of this method for more information.
289
- """
290
- return self.tokenizer.batch_decode(*args, **kwargs)
291
-
292
- # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Florence2
293
- def decode(self, *args, **kwargs):
294
- """
295
- This method forwards all its arguments to BartTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
296
- the docstring of this method for more information.
297
- """
298
- return self.tokenizer.decode(*args, **kwargs)
299
-
300
- @property
301
- # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Florence2
302
- def model_input_names(self):
303
- tokenizer_input_names = self.tokenizer.model_input_names
304
- image_processor_input_names = self.image_processor.model_input_names
305
- return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
306
-
307
- def post_process_generation(self, text, task, image_size):
308
- """
309
- Post-process the output of the model to each of the task outputs.
310
-
311
- Args:
312
- text (`str`): The text to post-process.
313
- task (`str`): The task to post-process the text for.
314
- image_size (`Tuple[int, int]`): The size of the image. height x width.
315
- """
316
-
317
- task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
318
- task_answer = self.post_processor(
319
- text=text,
320
- image_size=image_size,
321
- parse_tasks=task_answer_post_processing_type,
322
- )[task_answer_post_processing_type]
323
-
324
- if task_answer_post_processing_type == 'pure_text':
325
- final_answer = task_answer
326
- # remove the special tokens
327
- final_answer = final_answer.replace('<s>', '').replace('</s>', '')
328
- elif task_answer_post_processing_type in ['od', 'description_with_bboxes', 'bboxes']:
329
- od_instances = task_answer
330
- bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
331
- labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
332
- final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
333
- elif task_answer_post_processing_type in ['ocr']:
334
- bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
335
- labels = [str(_od_instance['text']) for _od_instance in task_answer]
336
- final_answer = {'quad_boxes': bboxes, 'labels': labels}
337
- elif task_answer_post_processing_type in ['phrase_grounding']:
338
- bboxes = []
339
- labels = []
340
- for _grounded_phrase in task_answer:
341
- for _bbox in _grounded_phrase['bbox']:
342
- bboxes.append(_bbox)
343
- labels.append(_grounded_phrase['cat_name'])
344
- final_answer = {'bboxes': bboxes, 'labels': labels}
345
- elif task_answer_post_processing_type in ['description_with_polygons', 'polygons']:
346
- labels = []
347
- polygons = []
348
- for result in task_answer:
349
- label = result['cat_name']
350
- _polygons = result['polygons']
351
- labels.append(label)
352
- polygons.append(_polygons)
353
- final_answer = {'polygons': polygons, 'labels': labels}
354
- elif task_answer_post_processing_type in ['description_with_bboxes_or_polygons']:
355
- bboxes = []
356
- bboxes_labels = []
357
- polygons = []
358
- polygons_labels = []
359
- for result in task_answer:
360
- label = result['cat_name']
361
- if 'polygons' in result:
362
- _polygons = result['polygons']
363
- polygons.append(_polygons)
364
- polygons_labels.append(label)
365
- else:
366
- _bbox = result['bbox']
367
- bboxes.append(_bbox)
368
- bboxes_labels.append(label)
369
- final_answer = {'bboxes': bboxes, 'bboxes_labels': bboxes_labels, 'polygons': polygons, 'polygons_labels': polygons_labels}
370
- else:
371
- raise ValueError('Unknown task answer post processing type: {}'.format(task_answer_post_processing_type))
372
-
373
- final_answer = {
374
- task: final_answer}
375
- return final_answer
376
-
377
- class BoxQuantizer(object):
378
- def __init__(self, mode, bins):
379
- self.mode = mode
380
- self.bins = bins
381
-
382
- def quantize(self, boxes: torch.Tensor, size):
383
- bins_w, bins_h = self.bins # Quantization bins.
384
- size_w, size_h = size # Original image size.
385
- size_per_bin_w = size_w / bins_w
386
- size_per_bin_h = size_h / bins_h
387
- xmin, ymin, xmax, ymax = boxes.split(1, dim=-1) # Shape: 4 * [N, 1].
388
-
389
- if self.mode == 'floor':
390
- quantized_xmin = (
391
- xmin / size_per_bin_w).floor().clamp(0, bins_w - 1)
392
- quantized_ymin = (
393
- ymin / size_per_bin_h).floor().clamp(0, bins_h - 1)
394
- quantized_xmax = (
395
- xmax / size_per_bin_w).floor().clamp(0, bins_w - 1)
396
- quantized_ymax = (
397
- ymax / size_per_bin_h).floor().clamp(0, bins_h - 1)
398
-
399
- elif self.mode == 'round':
400
- raise NotImplementedError()
401
-
402
- else:
403
- raise ValueError('Incorrect quantization type.')
404
-
405
- quantized_boxes = torch.cat(
406
- (quantized_xmin, quantized_ymin, quantized_xmax, quantized_ymax), dim=-1
407
- ).int()
408
-
409
- return quantized_boxes
410
-
411
- def dequantize(self, boxes: torch.Tensor, size):
412
- bins_w, bins_h = self.bins # Quantization bins.
413
- size_w, size_h = size # Original image size.
414
- size_per_bin_w = size_w / bins_w
415
- size_per_bin_h = size_h / bins_h
416
- xmin, ymin, xmax, ymax = boxes.split(1, dim=-1) # Shape: 4 * [N, 1].
417
-
418
- if self.mode == 'floor':
419
- # Add 0.5 to use the center position of the bin as the coordinate.
420
- dequantized_xmin = (xmin + 0.5) * size_per_bin_w
421
- dequantized_ymin = (ymin + 0.5) * size_per_bin_h
422
- dequantized_xmax = (xmax + 0.5) * size_per_bin_w
423
- dequantized_ymax = (ymax + 0.5) * size_per_bin_h
424
-
425
- elif self.mode == 'round':
426
- raise NotImplementedError()
427
-
428
- else:
429
- raise ValueError('Incorrect quantization type.')
430
-
431
- dequantized_boxes = torch.cat(
432
- (dequantized_xmin, dequantized_ymin,
433
- dequantized_xmax, dequantized_ymax), dim=-1
434
- )
435
-
436
- return dequantized_boxes
437
-
438
-
439
- class CoordinatesQuantizer(object):
440
- """
441
- Quantize coornidates (Nx2)
442
- """
443
-
444
- def __init__(self, mode, bins):
445
- self.mode = mode
446
- self.bins = bins
447
-
448
- def quantize(self, coordinates: torch.Tensor, size):
449
- bins_w, bins_h = self.bins # Quantization bins.
450
- size_w, size_h = size # Original image size.
451
- size_per_bin_w = size_w / bins_w
452
- size_per_bin_h = size_h / bins_h
453
- assert coordinates.shape[-1] == 2, 'coordinates should be shape (N, 2)'
454
- x, y = coordinates.split(1, dim=-1) # Shape: 4 * [N, 1].
455
-
456
- if self.mode == 'floor':
457
- quantized_x = (x / size_per_bin_w).floor().clamp(0, bins_w - 1)
458
- quantized_y = (y / size_per_bin_h).floor().clamp(0, bins_h - 1)
459
-
460
- elif self.mode == 'round':
461
- raise NotImplementedError()
462
-
463
- else:
464
- raise ValueError('Incorrect quantization type.')
465
-
466
- quantized_coordinates = torch.cat(
467
- (quantized_x, quantized_y), dim=-1
468
- ).int()
469
-
470
- return quantized_coordinates
471
-
472
- def dequantize(self, coordinates: torch.Tensor, size):
473
- bins_w, bins_h = self.bins # Quantization bins.
474
- size_w, size_h = size # Original image size.
475
- size_per_bin_w = size_w / bins_w
476
- size_per_bin_h = size_h / bins_h
477
- assert coordinates.shape[-1] == 2, 'coordinates should be shape (N, 2)'
478
- x, y = coordinates.split(1, dim=-1) # Shape: 4 * [N, 1].
479
-
480
- if self.mode == 'floor':
481
- # Add 0.5 to use the center position of the bin as the coordinate.
482
- dequantized_x = (x + 0.5) * size_per_bin_w
483
- dequantized_y = (y + 0.5) * size_per_bin_h
484
-
485
- elif self.mode == 'round':
486
- raise NotImplementedError()
487
-
488
- else:
489
- raise ValueError('Incorrect quantization type.')
490
-
491
- dequantized_coordinates = torch.cat(
492
- (dequantized_x, dequantized_y), dim=-1
493
- )
494
-
495
- return dequantized_coordinates
496
-
497
-
498
- class Florence2PostProcesser(object):
499
- """
500
- Florence-2 post process for converting text prediction to various tasks results.
501
-
502
- Args:
503
- config: A dict of configs.
504
- tokenizer: A tokenizer for decoding text to spans.
505
- sample config:
506
- UNIFIED_POST_PROCESS:
507
- # commom configs
508
- NUM_BBOX_HEIGHT_BINS: 1000
509
- NUM_BBOX_WIDTH_BINS: 1000
510
- COORDINATES_HEIGHT_BINS: 1000
511
- COORDINATES_WIDTH_BINS: 1000
512
- # task specific configs, override the common configs
513
- PRASE_TASKS:
514
- - TASK_NAME: 'video_dense_caption'
515
- PATTERN: 'r<time_(\d+)><time_(\d+)>([a-zA-Z0-9 ]+)'
516
- SCORE_MODE: 'avg_cat_name_scores'
517
- NUM_BINS: 100
518
- - TASK_NAME: 'od'
519
- PATTERN: 'r<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>([a-zA-Z0-9 ]+)'
520
- SCORE_MODE: 'avg_cat_name_scores'
521
-
522
- Returns:
523
- parsed_dict (dict): A dict of parsed results.
524
- """
525
- def __init__(
526
- self,
527
- tokenizer=None
528
- ):
529
- parse_tasks = []
530
- parse_task_configs = {}
531
- config = self._create_default_config()
532
- for task in config['PARSE_TASKS']:
533
- parse_tasks.append(task['TASK_NAME'])
534
- parse_task_configs[task['TASK_NAME']] = task
535
-
536
- self.config = config
537
- self.parse_tasks = parse_tasks
538
- self.parse_tasks_configs = parse_task_configs
539
-
540
- self.tokenizer = tokenizer
541
- if self.tokenizer is not None:
542
- self.all_special_tokens = set(self.tokenizer.all_special_tokens)
543
-
544
- self.init_quantizers()
545
- self.black_list_of_phrase_grounding = self._create_black_list_of_phrase_grounding()
546
-
547
- def _create_black_list_of_phrase_grounding(self):
548
- black_list = {}
549
-
550
- if 'phrase_grounding' in self.parse_tasks and self.parse_tasks_configs['phrase_grounding']['FILTER_BY_BLACK_LIST']:
551
- black_list = set(
552
- ['it', 'I', 'me', 'mine',
553
- 'you', 'your', 'yours',
554
- 'he', 'him', 'his',
555
- 'she', 'her', 'hers',
556
- 'they', 'them', 'their', 'theirs',
557
- 'one', 'oneself',
558
- 'we', 'us', 'our', 'ours',
559
- 'you', 'your', 'yours',
560
- 'they', 'them', 'their', 'theirs',
561
- 'mine', 'yours', 'his', 'hers', 'its',
562
- 'ours', 'yours', 'theirs',
563
- 'myself', 'yourself', 'himself', 'herself', 'itself',
564
- 'ourselves', 'yourselves', 'themselves',
565
- 'this', 'that',
566
- 'these', 'those',
567
- 'who', 'whom', 'whose', 'which', 'what',
568
- 'who', 'whom', 'whose', 'which', 'that',
569
- 'all', 'another', 'any', 'anybody', 'anyone', 'anything',
570
- 'each', 'everybody', 'everyone', 'everything',
571
- 'few', 'many', 'nobody', 'none', 'one', 'several',
572
- 'some', 'somebody', 'someone', 'something',
573
- 'each other', 'one another',
574
- 'myself', 'yourself', 'himself', 'herself', 'itself',
575
- 'ourselves', 'yourselves', 'themselves',
576
- 'the image', 'image', 'images', 'the', 'a', 'an', 'a group',
577
- 'other objects', 'lots', 'a set',
578
- ]
579
- )
580
-
581
- return black_list
582
-
583
- def _create_default_config(self):
584
- config = {
585
- 'NUM_BBOX_HEIGHT_BINS': 1000,
586
- 'NUM_BBOX_WIDTH_BINS': 1000,
587
- 'BOX_QUANTIZATION_MODE': 'floor',
588
- 'COORDINATES_HEIGHT_BINS': 1000,
589
- 'COORDINATES_WIDTH_BINS': 1000,
590
- 'COORDINATES_QUANTIZATION_MODE': 'floor',
591
- 'PARSE_TASKS': [
592
- {
593
- 'TASK_NAME': 'od',
594
- 'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>'
595
- },
596
- {
597
- 'TASK_NAME': 'ocr',
598
- 'PATTERN': r'(.+?)<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>',
599
- 'AREA_THRESHOLD': 0.01
600
- },
601
- {
602
- 'TASK_NAME': 'phrase_grounding',
603
- 'FILTER_BY_BLACK_LIST': True
604
- },
605
- {
606
- 'TASK_NAME': 'pure_text',
607
- },
608
- {
609
- 'TASK_NAME': 'description_with_bboxes',
610
- },
611
- {
612
- 'TASK_NAME': 'description_with_polygons',
613
- },
614
- {
615
- 'TASK_NAME': 'polygons',
616
- },
617
- {
618
- 'TASK_NAME': 'bboxes',
619
- },
620
- {
621
- 'TASK_NAME': 'description_with_bboxes_or_polygons',
622
- }
623
- ]
624
- }
625
-
626
- return config
627
-
628
- def init_quantizers(self):
629
- # we have box_quantizer (od, grounding) and coordinates_quantizer (ocr, referring_segmentation)
630
- num_bbox_height_bins = self.config.get('NUM_BBOX_HEIGHT_BINS', 1000)
631
- num_bbox_width_bins = self.config.get('NUM_BBOX_WIDTH_BINS', 1000)
632
- box_quantization_mode = self.config.get('BOX_QUANTIZATION_MODE', 'floor')
633
- self.box_quantizer = BoxQuantizer(
634
- box_quantization_mode,
635
- (num_bbox_width_bins, num_bbox_height_bins),
636
- )
637
-
638
- num_bbox_height_bins = self.config['COORDINATES_HEIGHT_BINS'] if 'COORDINATES_HEIGHT_BINS' in self.config else self.config.get('NUM_BBOX_HEIGHT_BINS', 1000)
639
- num_bbox_width_bins = self.config['COORDINATES_WIDTH_BINS'] if 'COORDINATES_WIDTH_BINS' in self.config else self.config.get('NUM_BBOX_WIDTH_BINS', 1000)
640
- box_quantization_mode = self.config.get('COORDINATES_QUANTIZATION_MODE') if 'COORDINATES_QUANTIZATION_MODE' in self.config else self.config.get('BOX_QUANTIZATION_MODE', 'floor')
641
- self.coordinates_quantizer = CoordinatesQuantizer(
642
- box_quantization_mode,
643
- (num_bbox_width_bins, num_bbox_height_bins),
644
- )
645
-
646
- def decode_with_spans(self, tokenizer, token_ids):
647
- filtered_tokens = tokenizer.convert_ids_to_tokens(
648
- token_ids, skip_special_tokens=False)
649
- assert len(filtered_tokens) == len(token_ids)
650
-
651
- # To avoid mixing byte-level and unicode for byte-level BPT
652
- # we need to build string separately for added tokens and byte-level tokens
653
- # cf. https://github.com/huggingface/transformers/issues/1133
654
- sub_texts = []
655
- for token in filtered_tokens:
656
- if token in self.all_special_tokens:
657
- sub_texts.append(token)
658
- else:
659
- if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
660
- sub_text = tokenizer.convert_tokens_to_string([token])
661
- elif isinstance(tokenizer, (T5Tokenizer, T5TokenizerFast)):
662
- # Ref: https://github.com/google/sentencepiece#whitespace-is-treated-as-a-basic-symbol
663
- # Note: Do not strip sub_text as it may have functional whitespace
664
- sub_text = token.replace('▁', ' ')
665
- else:
666
- raise ValueError(f'type {type(tokenizer)} not supported')
667
- sub_texts.append(sub_text)
668
-
669
- text = ''
670
- spans = []
671
- for sub_text in sub_texts:
672
- span = (len(text), len(text) + len(sub_text)) # [start index, end index).
673
- text += sub_text
674
- spans.append(span)
675
-
676
- # Text format:
677
- # 1. T5Tokenizer/T5TokenizerFast:
678
- # "<loc_1><loc_2><loc_3><loc_4> transplanting dog<loc_1><loc_2><loc_3><loc_4> cat</s>"
679
- # Equivalent to t5_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
680
- # 2. BartTokenizer (need to double check):
681
- # "<s><loc_1><loc_2><loc_3><loc_4>transplanting dog<loc_1><loc_2><loc_3><loc_4>cat</s>"
682
- # Equivalent to bart_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
683
- return text, spans
684
-
685
- def parse_od_from_text_and_spans(
686
- self,
687
- text,
688
- pattern,
689
- image_size,
690
- phrase_centric=False
691
- ):
692
- parsed = list(re.finditer(pattern, text))
693
-
694
- instances = []
695
- for i in range(len(parsed)):
696
- # Prepare instance.
697
- instance = {}
698
-
699
- if phrase_centric:
700
- bbox_bins = [int(parsed[i].group(j)) for j in range(2, 6)]
701
- else:
702
- bbox_bins = [int(parsed[i].group(j)) for j in range(1, 5)]
703
- instance['bbox'] = self.box_quantizer.dequantize(
704
- boxes=torch.tensor(bbox_bins),
705
- size=image_size
706
- ).tolist()
707
-
708
- if phrase_centric:
709
- instance['cat_name'] = parsed[i].group(1).lower().strip()
710
- else:
711
- instance['cat_name'] = parsed[i].group(5).lower().strip()
712
- instances.append(instance)
713
-
714
- return instances
715
-
716
- def parse_ocr_from_text_and_spans(self,
717
- text,
718
- pattern,
719
- image_size,
720
- area_threshold=-1.0,
721
- ):
722
- bboxes = []
723
- labels = []
724
- text = text.replace('<s>', '')
725
- # ocr with regions
726
- parsed = re.findall(pattern, text)
727
- instances = []
728
- image_width, image_height = image_size
729
-
730
- for ocr_line in parsed:
731
- ocr_content = ocr_line[0]
732
- quad_box = ocr_line[1:]
733
- quad_box = [int(i) for i in quad_box]
734
- quad_box = self.coordinates_quantizer.dequantize(
735
- torch.tensor(np.array(quad_box).reshape(-1, 2)),
736
- size=image_size
737
- ).reshape(-1).tolist()
738
-
739
- if area_threshold > 0:
740
- x_coords = [i for i in quad_box[0::2]]
741
- y_coords = [i for i in quad_box[1::2]]
742
-
743
- # apply the Shoelace formula
744
- area = 0.5 * abs(sum(x_coords[i] * y_coords[i + 1] - x_coords[i + 1] * y_coords[i] for i in range(4 - 1)))
745
-
746
- if area < (image_width * image_height) * area_threshold:
747
- continue
748
-
749
- bboxes.append(quad_box)
750
- labels.append(ocr_content)
751
- instances.append({
752
- 'quad_box': quad_box,
753
- 'text': ocr_content,
754
- })
755
- return instances
756
-
757
- def parse_phrase_grounding_from_text_and_spans(self, text, pattern, image_size):
758
- # ignore <s> </s> and <pad>
759
- cur_span = 0
760
- if text.startswith('<s>'):
761
- cur_span += 3
762
-
763
- text = text.replace('<s>', '')
764
- text = text.replace('</s>', '')
765
- text = text.replace('<pad>', '')
766
-
767
- pattern = r"([^<]+(?:<loc_\d+>){4,})"
768
- phrases = re.findall(pattern, text)
769
-
770
- # pattern should be text pattern and od pattern
771
- pattern = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_)'
772
- box_pattern = r'<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>'
773
-
774
- instances = []
775
- for pharse_text in phrases:
776
- phrase_text_strip = pharse_text.replace('<ground>', '', 1)
777
- phrase_text_strip = pharse_text.replace('<obj>', '', 1)
778
-
779
- if phrase_text_strip == '':
780
- cur_span += len(pharse_text)
781
- continue
782
-
783
- # Prepare instance.
784
- instance = {}
785
-
786
- # parse phrase, get string
787
- phrase = re.search(pattern, phrase_text_strip)
788
- if phrase is None:
789
- cur_span += len(pharse_text)
790
- continue
791
-
792
- # parse bboxes by box_pattern
793
- bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
794
- if len(bboxes_parsed) == 0:
795
- cur_span += len(pharse_text)
796
- continue
797
-
798
- phrase = phrase.group()
799
- # remove leading and trailing spaces
800
- phrase = phrase.strip()
801
-
802
- if phrase in self.black_list_of_phrase_grounding:
803
- cur_span += len(pharse_text)
804
- continue
805
-
806
- # a list of list
807
- bbox_bins = [[int(_bboxes_parsed.group(j)) for j in range(1, 5)] for _bboxes_parsed in bboxes_parsed]
808
- instance['bbox'] = self.box_quantizer.dequantize(
809
- boxes=torch.tensor(bbox_bins),
810
- size=image_size
811
- ).tolist()
812
-
813
- # exclude non-ascii characters
814
- phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
815
- instance['cat_name'] = phrase
816
-
817
- instances.append(instance)
818
-
819
- return instances
820
-
821
- def parse_description_with_bboxes_from_text_and_spans(self, text, pattern, image_size, allow_empty_phrase=False):
822
- # temporary parse solution, split by '.'
823
- # ignore <s> </s> and <pad>
824
-
825
- text = text.replace('<s>', '')
826
- text = text.replace('</s>', '')
827
- text = text.replace('<pad>', '')
828
-
829
- if allow_empty_phrase:
830
- pattern = rf"(?:(?:<loc_\d+>){{4,}})"
831
- else:
832
- pattern = r"([^<]+(?:<loc_\d+>){4,})"
833
- phrases = re.findall(pattern, text)
834
-
835
- # pattern should be text pattern and od pattern
836
- pattern = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_)'
837
- box_pattern = r'<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>'
838
-
839
- instances = []
840
- for pharse_text in phrases:
841
- phrase_text_strip = pharse_text.replace('<ground>', '', 1)
842
- phrase_text_strip = pharse_text.replace('<obj>', '', 1)
843
-
844
- if phrase_text_strip == '' and not allow_empty_phrase:
845
- continue
846
-
847
- # parse phrase, get string
848
- phrase = re.search(pattern, phrase_text_strip)
849
- if phrase is None:
850
- continue
851
-
852
- phrase = phrase.group()
853
- # remove leading and trailing spaces
854
- phrase = phrase.strip()
855
-
856
- # parse bboxes by box_pattern
857
- bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
858
- if len(bboxes_parsed) == 0:
859
- continue
860
-
861
- # a list of list
862
- bbox_bins = [[int(_bboxes_parsed.group(j)) for j in range(1, 5)] for _bboxes_parsed in bboxes_parsed]
863
-
864
- bboxes = self.box_quantizer.dequantize(
865
- boxes=torch.tensor(bbox_bins),
866
- size=image_size
867
- ).tolist()
868
-
869
- phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
870
- for _bboxes in bboxes:
871
- # Prepare instance.
872
- instance = {}
873
- instance['bbox'] = _bboxes
874
- # exclude non-ascii characters
875
- instance['cat_name'] = phrase
876
- instances.append(instance)
877
-
878
- return instances
879
-
880
- def parse_description_with_polygons_from_text_and_spans(self, text, pattern, image_size,
881
- allow_empty_phrase=False,
882
- polygon_sep_token='<sep>',
883
- polygon_start_token='<poly>',
884
- polygon_end_token='</poly>',
885
- with_box_at_start=False,
886
- ):
887
-
888
- # ref_seg format: '<expression><x1><y1><x2><y2><><><sep><><><><>'
889
- # ignore <s> </s> and <pad>
890
-
891
- text = text.replace('<s>', '')
892
- text = text.replace('</s>', '')
893
- text = text.replace('<pad>', '')
894
-
895
- if allow_empty_phrase:
896
- pattern = rf"(?:(?:<loc_\d+>|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})"
897
- else:
898
- # [^<]+: This part matches one or more characters that are not the < symbol.
899
- # The ^ inside the square brackets [] is a negation, meaning it matches anything except <.
900
- #
901
- pattern = rf"([^<]+(?:<loc_\d+>|{re.escape(polygon_sep_token)}|{re.escape(polygon_start_token)}|{re.escape(polygon_end_token)}){{4,}})"
902
- phrases = re.findall(pattern, text)
903
-
904
- phrase_string_pattern = r'^\s*(.*?)(?=<od>|</od>|<box>|</box>|<bbox>|</bbox>|<loc_|<poly>)'
905
- box_pattern = rf'((?:<loc_\d+>)+)(?:{re.escape(polygon_sep_token)}|$)'
906
-
907
- # one polygons instance is separated by polygon_start_token and polygon_end_token
908
- polygons_instance_pattern = rf'{re.escape(polygon_start_token)}(.*?){re.escape(polygon_end_token)}'
909
-
910
- instances = []
911
- for phrase_text in phrases:
912
-
913
- # exclude loc_\d+>
914
- # need to get span if want to include category score
915
- phrase_text_strip = re.sub(r'^loc_\d+>', '', phrase_text, count=1)
916
-
917
- # phrase = phrase.replace('<poly>', '')
918
- # phrase = phrase.replace('poly>', '')
919
-
920
- if phrase_text_strip == '' and not allow_empty_phrase:
921
- continue
922
-
923
-
924
- # parse phrase, get string
925
- phrase = re.search(phrase_string_pattern, phrase_text_strip)
926
- if phrase is None:
927
- continue
928
- phrase = phrase.group()
929
- # remove leading and trailing spaces
930
- phrase = phrase.strip()
931
-
932
- # parse bboxes by box_pattern
933
-
934
- # split by polygon_start_token and polygon_end_token first using polygons_instance_pattern
935
- if polygon_start_token in phrase_text and polygon_end_token in phrase_text:
936
- polygons_instances_parsed = list(re.finditer(polygons_instance_pattern, phrase_text))
937
- else:
938
- polygons_instances_parsed = [phrase_text]
939
-
940
- for _polygons_instances_parsed in polygons_instances_parsed:
941
- # Prepare instance.
942
- instance = {}
943
-
944
- # polygons_parsed= list(re.finditer(box_pattern, phrase_text))
945
- if isinstance(_polygons_instances_parsed, str):
946
- polygons_parsed= list(re.finditer(box_pattern, _polygons_instances_parsed))
947
- else:
948
- polygons_parsed= list(re.finditer(box_pattern, _polygons_instances_parsed.group(1)))
949
- if len(polygons_parsed) == 0:
950
- continue
951
-
952
- # a list of list (polygon)
953
- bbox = []
954
- polygons = []
955
- for _polygon_parsed in polygons_parsed:
956
- # group 1: whole <loc_\d+>...</loc_\d+>
957
- _polygon = _polygon_parsed.group(1)
958
- # parse into list of int
959
- _polygon = [int(_loc_parsed.group(1)) for _loc_parsed in re.finditer(r'<loc_(\d+)>', _polygon)]
960
- if with_box_at_start and len(bbox) == 0:
961
- if len(_polygon) > 4:
962
- # no valid bbox prediction
963
- bbox = _polygon[:4]
964
- _polygon = _polygon[4:]
965
- else:
966
- bbox = [0, 0, 0, 0]
967
- # abandon last element if is not paired
968
- if len(_polygon) % 2 == 1:
969
- _polygon = _polygon[:-1]
970
-
971
- # reshape into (n, 2)
972
- _polygon = self.coordinates_quantizer.dequantize(
973
- torch.tensor(np.array(_polygon).reshape(-1, 2)),
974
- size=image_size
975
- ).reshape(-1).tolist()
976
- # reshape back
977
- polygons.append(_polygon)
978
-
979
- instance['cat_name'] = phrase
980
- instance['polygons'] = polygons
981
- if len(bbox) != 0:
982
- instance['bbox'] = self.box_quantizer.dequantize(
983
- boxes=torch.tensor([bbox]),
984
- size=image_size
985
- ).tolist()[0]
986
-
987
- instances.append(instance)
988
-
989
- return instances
990
-
991
- def __call__(
992
- self,
993
- text=None,
994
- image_size=None,
995
- parse_tasks=None,
996
- ):
997
- """
998
- Args:
999
- text: model outputs
1000
- image_size: (width, height)
1001
- parse_tasks: a list of tasks to parse, if None, parse all tasks.
1002
-
1003
- """
1004
- if parse_tasks is not None:
1005
- if isinstance(parse_tasks, str):
1006
- parse_tasks = [parse_tasks]
1007
- for _parse_task in parse_tasks:
1008
- assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'
1009
-
1010
- # sequence or text should be provided
1011
- assert text is not None, 'text should be provided'
1012
-
1013
- parsed_dict = {
1014
- 'text': text
1015
- }
1016
-
1017
- for task in self.parse_tasks:
1018
- if parse_tasks is not None and task not in parse_tasks:
1019
- continue
1020
-
1021
- pattern = self.parse_tasks_configs[task].get('PATTERN', None)
1022
-
1023
- if task == 'ocr':
1024
- instances = self.parse_ocr_from_text_and_spans(
1025
- text,
1026
- pattern=pattern,
1027
- image_size=image_size,
1028
- area_threshold=self.parse_tasks_configs[task].get('AREA_THRESHOLD', 0.01),
1029
- )
1030
- parsed_dict['ocr'] = instances
1031
- elif task == 'phrase_grounding':
1032
- instances = self.parse_phrase_grounding_from_text_and_spans(
1033
- text,
1034
- pattern=pattern,
1035
- image_size=image_size,
1036
- )
1037
- parsed_dict['phrase_grounding'] = instances
1038
- elif task == 'pure_text':
1039
- parsed_dict['pure_text'] = text
1040
- elif task == 'description_with_bboxes':
1041
- instances = self.parse_description_with_bboxes_from_text_and_spans(
1042
- text,
1043
- pattern=pattern,
1044
- image_size=image_size,
1045
- )
1046
- parsed_dict['description_with_bboxes'] = instances
1047
- elif task == 'description_with_polygons':
1048
- instances = self.parse_description_with_polygons_from_text_and_spans(
1049
- text,
1050
- pattern=pattern,
1051
- image_size=image_size,
1052
- )
1053
- parsed_dict['description_with_polygons'] = instances
1054
- elif task == 'polygons':
1055
- instances = self.parse_description_with_polygons_from_text_and_spans(
1056
- text,
1057
- pattern=pattern,
1058
- image_size=image_size,
1059
- allow_empty_phrase=True,
1060
- )
1061
- parsed_dict['polygons'] = instances
1062
- elif task == 'bboxes':
1063
- instances = self.parse_description_with_bboxes_from_text_and_spans(
1064
- text,
1065
- pattern=pattern,
1066
- image_size=image_size,
1067
- allow_empty_phrase=True,
1068
- )
1069
- parsed_dict['bboxes'] = instances
1070
- elif task == 'description_with_bboxes_or_polygons':
1071
- if '<poly>' in text:
1072
- # only support either polygons or bboxes, not both at the same time
1073
- instances = self.parse_description_with_polygons_from_text_and_spans(
1074
- text,
1075
- pattern=pattern,
1076
- image_size=image_size,
1077
- )
1078
- else:
1079
- instances = self.parse_description_with_bboxes_from_text_and_spans(
1080
- text,
1081
- pattern=pattern,
1082
- image_size=image_size,
1083
- )
1084
- parsed_dict['description_with_bboxes_or_polygons'] = instances
1085
- else:
1086
- raise ValueError("task {} is not supported".format(task))
1087
-
1088
- return parsed_dict