omarelsayeed committed
Commit 1847114
Parent(s): 878a3fa
Update app.py
app.py CHANGED
@@ -11,34 +11,77 @@ from transformers import LayoutLMv3ForTokenClassification
 from transformers import AutoProcessor
 from transformers import AutoModelForTokenClassification
 
-
+finetuned_fully = LayoutLMv3ForTokenClassification.from_pretrained("omarelsayeed/finetuned_pretrained_model")
+
 processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base",
                                           apply_ocr=False)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    return predictions
-
-def get_orders(image_path, boxes):
-    b = scale_and_normalize_boxes(boxes)
-    orders = predict_reading_order(b, image_path)
-    return orders
+MAX_LEN = 70
+CLS_TOKEN_ID = 0
+UNK_TOKEN_ID = 3
+EOS_TOKEN_ID = 2
+import torch
+def boxes2inputs(boxes):
+    bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
+    input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
+    attention_mask = [1] + [1] * len(boxes) + [1]
+    return {
+        "bbox": torch.tensor([bbox]),
+        "attention_mask": torch.tensor([attention_mask]),
+        "input_ids": torch.tensor([input_ids]),
+    }
+def parse_logits(logits: torch.Tensor, length):
+    """
+    parse logits to orders
+
+    :param logits: logits from model
+    :param length: input length
+    :return: orders
+    """
+    logits = logits[1 : length + 1, :length]
+    orders = logits.argsort(descending=False).tolist()
+    ret = [o.pop() for o in orders]
+    while True:
+        order_to_idxes = defaultdict(list)
+        for idx, order in enumerate(ret):
+            order_to_idxes[order].append(idx)
+        # filter idxes len > 1
+        order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
+        if not order_to_idxes:
+            break
+        # filter
+        for order, idxes in order_to_idxes.items():
+            # find original logits of idxes
+            idxes_to_logit = {}
+            for idx in idxes:
+                idxes_to_logit[idx] = logits[idx, order]
+            idxes_to_logit = sorted(
+                idxes_to_logit.items(), key=lambda x: x[1], reverse=True
+            )
+            # keep the highest logit as order, set others to next candidate
+            for idx, _ in idxes_to_logit[1:]:
+                ret[idx] = orders[idx].pop()
+    return ret
+
+
+def prepare_inputs(
+    inputs, model
+):
+    ret = {}
+    for k, v in inputs.items():
+        v = v.to(model.device)
+        if torch.is_floating_point(v):
+            v = v.to(model.dtype)
+        ret[k] = v
+    return ret
+
+
+def get_orders(image_path , boxes):
+    inputs = boxes2inputs(boxes)
+    inputs = prepare_inputs(inputs, finetuned_fully)
+    logits = finetuned_fully(**inputs).logits.cpu().squeeze(0)
+    predictions = parse_logits(logits, len(boxes))
+    return predictions
 
 model_dir = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS") + "/rtdetr_1024_crops.pt"
 model = RTDETR(model_dir)
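For reference, the new reading-order path chains three helpers: boxes2inputs wraps the detected boxes in [CLS]/[UNK]/[EOS] pseudo-tokens so the token-classification head scores one token per box, prepare_inputs moves the tensors to the model's device (casting only floating-point tensors to its dtype), and parse_logits argsorts each box's logits over target positions, then repeatedly demotes the lower-scoring box whenever two boxes claim the same position. One caveat: parse_logits uses defaultdict, so the file needs "from collections import defaultdict" somewhere above this hunk. Below is a minimal usage sketch, assuming the functions from this commit are in scope; the box coordinates are hypothetical and assumed to already be normalized to the 0-1000 range LayoutLMv3 expects.

from collections import defaultdict  # required by parse_logits; not imported in this hunk

# Hypothetical [x_min, y_min, x_max, y_max] boxes, normalized to 0-1000.
boxes = [
    [600, 50, 950, 120],  # top-right block (often first in RTL reading order)
    [50, 50, 400, 120],   # top-left block
    [50, 200, 950, 600],  # full-width body block
]

# image_path is accepted for API compatibility but unused by the new pipeline.
order = get_orders("page.png", boxes)
print(order)  # order[i] is the predicted reading position of boxes[i]

The tie-breaking step can be checked in isolation with a hand-made logits tensor (the values below are made up). Boxes 0 and 1 both claim position 0; box 1 has the higher logit, so box 0 is bumped to its next-best candidate:

import torch

toy = torch.tensor([
    [0.0, 0.0, 0.0],  # [CLS] row, skipped by the slice in parse_logits
    [2.0, 1.0, 0.0],  # box 0 prefers position 0
    [3.0, 0.5, 0.2],  # box 1 also prefers position 0, more confidently
    [0.1, 0.3, 2.5],  # box 2 prefers position 2
    [0.0, 0.0, 0.0],  # [EOS] row, skipped
])
print(parse_logits(toy, 3))  # -> [1, 0, 2]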