Files changed (2) hide show
  1. conditional_detr_utils.py +179 -0
  2. modelling_magi.py +4 -3
conditional_detr_utils.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 Microsoft Research Asia and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Conditional DETR model."""
16
+
17
+ from transformers.utils import (
18
+ is_scipy_available,
19
+ is_vision_available,
20
+ logging
21
+ )
22
+
23
+ import torch
24
+ from torch import Tensor, nn
25
+
26
+ if is_scipy_available():
27
+ from scipy.optimize import linear_sum_assignment
28
+
29
+ if is_vision_available():
30
+ from transformers.image_transforms import center_to_corners_format
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+ # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrHungarianMatcher with DeformableDetr->ConditionalDetr
35
class ConditionalDetrHungarianMatcher(nn.Module):
    """
    Solves a bipartite assignment between the network's predictions and the ground-truth targets.

    The targets never contain the no-object class, so there are generally more predictions than
    targets: the best predictions are matched 1-to-1 with the targets and every remaining
    prediction is left unmatched (i.e. treated as a non-object).

    Args:
        class_cost:
            The relative weight of the classification error in the matching cost.
        bbox_cost:
            The relative weight of the L1 error of the bounding box coordinates in the matching cost.
        giou_cost:
            The relative weight of the giou loss of the bounding box in the matching cost.
    """

    def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
        super().__init__()

        self.class_cost = class_cost
        self.bbox_cost = bbox_cost
        self.giou_cost = giou_cost
        # A matcher with all-zero weights would make every assignment equally good.
        if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
            raise ValueError("All costs of the Matcher can't be 0")

    @torch.no_grad()
    def forward(self, outputs, targets):
        """
        Compute the optimal prediction-to-target assignment for each image in the batch.

        Args:
            outputs (`dict`):
                Must contain at least:
                * "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                * "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
            targets (`List[dict]`):
                A list of length batch_size; each entry holds:
                * "class_labels": Tensor of dim [num_target_boxes] with the ground-truth class labels
                * "boxes": Tensor of dim [num_target_boxes, 4] with the target box coordinates.

        Returns:
            `List[Tuple]`: one `(index_i, index_j)` pair per batch element, where `index_i` indexes
            the selected predictions and `index_j` the matching targets, both in assignment order.
            For each element, len(index_i) == len(index_j) == min(num_queries, num_target_boxes).
        """
        batch_size, num_queries = outputs["logits"].shape[:2]

        # Collapse batch and query dims so a single large cost matrix can be built at once.
        pred_probs = outputs["logits"].flatten(0, 1).sigmoid()  # [batch_size * num_queries, num_classes]
        pred_boxes = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Concatenate the labels and boxes of every image in the batch.
        gt_labels = torch.cat([t["class_labels"] for t in targets])
        gt_boxes = torch.cat([t["boxes"] for t in targets])

        # Classification cost in focal-loss style (alpha/gamma as in the focal loss paper);
        # the 1e-8 guards the log against exact 0/1 probabilities.
        alpha = 0.25
        gamma = 2.0
        neg_term = (1 - alpha) * (pred_probs**gamma) * (-(1 - pred_probs + 1e-8).log())
        pos_term = alpha * ((1 - pred_probs) ** gamma) * (-(pred_probs + 1e-8).log())
        class_cost = pos_term[:, gt_labels] - neg_term[:, gt_labels]

        # L1 distance between predicted and target boxes.
        bbox_cost = torch.cdist(pred_boxes, gt_boxes, p=1)

        # Negative generalized IoU (boxes are converted from center to corner format first).
        giou_cost = -generalized_box_iou(center_to_corners_format(pred_boxes), center_to_corners_format(gt_boxes))

        # Weighted sum of the three costs, reshaped back to one matrix per image; the Hungarian
        # solver runs on CPU.
        cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
        cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()

        sizes = [len(t["boxes"]) for t in targets]
        indices = [linear_sum_assignment(chunk[i]) for i, chunk in enumerate(cost_matrix.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
112
+
113
+
114
+ # Copied from transformers.models.detr.modeling_detr._upcast
115
+ def _upcast(t: Tensor) -> Tensor:
116
+ # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
117
+ if t.is_floating_point():
118
+ return t if t.dtype in (torch.float32, torch.float64) else t.float()
119
+ else:
120
+ return t if t.dtype in (torch.int32, torch.int64) else t.int()
121
+
122
+
123
+ # Copied from transformers.models.detr.modeling_detr.box_area
124
def box_area(boxes: Tensor) -> Tensor:
    """
    Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.

    Args:
        boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
            Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
            < x2` and `0 <= y1 < y2`.

    Returns:
        `torch.FloatTensor`: a tensor containing the area for each box.
    """
    # Upcast first so width * height cannot overflow narrow dtypes.
    boxes = _upcast(boxes)
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    return widths * heights
138
+
139
+
140
+ # Copied from transformers.models.detr.modeling_detr.box_iou
141
def box_iou(boxes1, boxes2):
    """
    Pairwise intersection-over-union between two sets of corner-format boxes.

    Returns a tuple `(iou, union)`, each of shape [N, M] where N = len(boxes1) and
    M = len(boxes2).
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    # Broadcast boxes1 against boxes2 to get the intersection rectangle of every pair.
    upper_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    lower_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # Clamp at zero: non-overlapping pairs get an empty intersection.
    inter_wh = (lower_right - upper_left).clamp(min=0)  # [N,M,2]
    intersection = inter_wh[..., 0] * inter_wh[..., 1]  # [N,M]

    union = area1[:, None] + area2 - intersection

    return intersection / union, union
155
+
156
+
157
+ # Copied from transformers.models.detr.modeling_detr.generalized_box_iou
158
def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.

    Returns:
        `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
    """
    # Degenerate boxes (x1 < x0 or y1 < y0) produce inf / nan downstream, so validate up front.
    if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
        raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
    if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
        raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
    iou, union = box_iou(boxes1, boxes2)

    # Smallest enclosing box of every pair.
    enclosing_tl = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    enclosing_br = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    enclosing_wh = (enclosing_br - enclosing_tl).clamp(min=0)  # [N,M,2]
    enclosing_area = enclosing_wh[..., 0] * enclosing_wh[..., 1]

    # GIoU = IoU - (C - U) / C, where C is the enclosing-box area and U the union.
    return iou - (enclosing_area - union) / enclosing_area
modelling_magi.py CHANGED
@@ -2,15 +2,15 @@ from transformers import PreTrainedModel, VisionEncoderDecoderModel, ViTMAEModel
2
  from transformers.models.conditional_detr.modeling_conditional_detr import (
3
  ConditionalDetrMLPPredictionHead,
4
  ConditionalDetrModelOutput,
5
- ConditionalDetrHungarianMatcher,
6
  inverse_sigmoid,
7
  )
 
8
  from .configuration_magi import MagiConfig
9
  from .processing_magi import MagiProcessor
10
  from torch import nn
11
  from typing import Optional, List
12
  import torch
13
- from einops import rearrange, repeat, einsum
14
  from .utils import move_to_device, visualise_single_image_prediction, sort_panels, sort_text_boxes_in_reading_order
15
 
16
  class MagiModel(PreTrainedModel):
@@ -498,4 +498,5 @@ class MagiModel(PreTrainedModel):
498
  if apply_sigmoid:
499
  text_character_affinities = text_character_affinities.sigmoid()
500
  affinity_matrices.append(text_character_affinities)
501
- return affinity_matrices
 
 
2
  from transformers.models.conditional_detr.modeling_conditional_detr import (
3
  ConditionalDetrMLPPredictionHead,
4
  ConditionalDetrModelOutput,
 
5
  inverse_sigmoid,
6
  )
7
+ from .conditional_detr_utils import ConditionalDetrHungarianMatcher
8
  from .configuration_magi import MagiConfig
9
  from .processing_magi import MagiProcessor
10
  from torch import nn
11
  from typing import Optional, List
12
  import torch
13
+ from einops import rearrange, repeat
14
  from .utils import move_to_device, visualise_single_image_prediction, sort_panels, sort_text_boxes_in_reading_order
15
 
16
  class MagiModel(PreTrainedModel):
 
498
  if apply_sigmoid:
499
  text_character_affinities = text_character_affinities.sigmoid()
500
  affinity_matrices.append(text_character_affinities)
501
+ return affinity_matrices
502
+