omarelsayeed committed on
Commit 1847114
1 Parent(s): 878a3fa

Update app.py

Files changed (1)
  1. app.py +67 -24
app.py CHANGED
@@ -11,34 +11,77 @@ from transformers import LayoutLMv3ForTokenClassification
  from transformers import AutoProcessor
  from transformers import AutoModelForTokenClassification

- reading_order_model = AutoModelForTokenClassification.from_pretrained("omarelsayeed/LayoutReader85Large")
+ finetuned_fully = LayoutLMv3ForTokenClassification.from_pretrained("omarelsayeed/finetuned_pretrained_model")
+
  processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base",
                                            apply_ocr=False)

- def predict_reading_order(boxes, image_path):
-     words = ["<unk>"] * len(boxes)
-     print(boxes)
-     encoding = processor(image_path, text=words,
-                          boxes=boxes,
-                          return_tensors="pt",
-                          return_offsets_mapping=True)
-     offset_mapping = encoding.pop('offset_mapping')
-     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-     for k, v in encoding.items():
-         encoding[k] = v.to(device)
-     outputs = reading_order_model(**encoding)
-     predictions = outputs.logits.argmax(-1).squeeze().tolist()
-     token_boxes = encoding.bbox.squeeze().tolist()
-     is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
-     # true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
-     predictions = predictions[1:-1]
-     return predictions
-
- def get_orders(image_path, boxes):
-     b = scale_and_normalize_boxes(boxes)
-     orders = predict_reading_order(b, image_path)
-     return orders
+ MAX_LEN = 70
+ CLS_TOKEN_ID = 0
+ UNK_TOKEN_ID = 3
+ EOS_TOKEN_ID = 2
+ import torch
+ def boxes2inputs(boxes):
+     bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
+     input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
+     attention_mask = [1] + [1] * len(boxes) + [1]
+     return {
+         "bbox": torch.tensor([bbox]),
+         "attention_mask": torch.tensor([attention_mask]),
+         "input_ids": torch.tensor([input_ids]),
+     }
+ def parse_logits(logits: torch.Tensor, length):
+     """
+     parse logits to orders
+
+     :param logits: logits from model
+     :param length: input length
+     :return: orders
+     """
+     logits = logits[1 : length + 1, :length]
+     orders = logits.argsort(descending=False).tolist()
+     ret = [o.pop() for o in orders]
+     while True:
+         order_to_idxes = defaultdict(list)
+         for idx, order in enumerate(ret):
+             order_to_idxes[order].append(idx)
+         # filter idxes len > 1
+         order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
+         if not order_to_idxes:
+             break
+         # filter
+         for order, idxes in order_to_idxes.items():
+             # find original logits of idxes
+             idxes_to_logit = {}
+             for idx in idxes:
+                 idxes_to_logit[idx] = logits[idx, order]
+             idxes_to_logit = sorted(
+                 idxes_to_logit.items(), key=lambda x: x[1], reverse=True
+             )
+             # keep the highest logit as order, set others to next candidate
+             for idx, _ in idxes_to_logit[1:]:
+                 ret[idx] = orders[idx].pop()
+     return ret
+
+
+ def prepare_inputs(
+     inputs, model
+ ):
+     ret = {}
+     for k, v in inputs.items():
+         v = v.to(model.device)
+         if torch.is_floating_point(v):
+             v = v.to(model.dtype)
+         ret[k] = v
+     return ret
+
+
+ def get_orders(image_path, boxes):
+     inputs = boxes2inputs(boxes)
+     inputs = prepare_inputs(inputs, finetuned_fully)
+     logits = finetuned_fully(**inputs).logits.cpu().squeeze(0)
+     predictions = parse_logits(logits, len(boxes))
+     return predictions

  model_dir = snapshot_download("omarelsayeed/DETR-ARABIC-DOCUMENT-LAYOUT-ANALYSIS") + "/rtdetr_1024_crops.pt"
  model = RTDETR(model_dir)
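
One behavioral note on the new get_orders: it no longer calls scale_and_normalize_boxes and never uses image_path, so callers are presumably expected to pass boxes already normalized to LayoutLMv3's 0-1000 coordinate space. Below is a minimal, hypothetical helper for that step; it is not part of this commit, and the name normalize_boxes and the pixel-to-1000 convention are assumptions.

# Hypothetical helper, not in app.py: scale pixel boxes to the 0-1000 space
# that LayoutLMv3 bboxes conventionally use, roughly what the removed
# scale_and_normalize_boxes call presumably handled.
def normalize_boxes(boxes, page_width, page_height):
    return [
        [
            int(1000 * x0 / page_width),
            int(1000 * y0 / page_height),
            int(1000 * x1 / page_width),
            int(1000 * y1 / page_height),
        ]
        for x0, y0, x1, y1 in boxes
    ]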
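
For reference, a standalone sketch of how the new helpers behave, with the model call replaced by an invented logits tensor. It assumes boxes2inputs and parse_logits from the updated app.py are in scope (parse_logits also relies on collections.defaultdict being imported elsewhere in the file), and every numeric value below is illustrative, not real model output.

# Walk-through sketch; all numbers are invented for illustration.
import torch

boxes = [
    [600, 100, 950, 150],   # invented box coordinates, already in 0-1000 space
    [50, 100, 400, 150],
    [50, 200, 950, 400],
]

inputs = boxes2inputs(boxes)
print(inputs["input_ids"])   # tensor([[0, 3, 3, 3, 2]]): CLS, one UNK per box, EOS

# Fake logits: row i scores box i against each candidate reading position j.
# Boxes 0 and 1 both prefer position 0; parse_logits keeps the stronger claim
# (box 0) and moves box 1 to its next-best position.
fake_logits = torch.tensor([
    [0.0, 0.0, 0.0],   # CLS row, dropped by the logits[1 : length + 1] slice
    [0.9, 0.1, 0.0],   # box 0
    [0.8, 0.7, 0.0],   # box 1
    [0.0, 0.0, 0.9],   # box 2
    [0.0, 0.0, 0.0],   # EOS row, also dropped
])
print(parse_logits(fake_logits, len(boxes)))   # [0, 1, 2]

In app.py itself, prepare_inputs and get_orders simply move the boxes2inputs tensors to the model's device and dtype and take the real forward pass's logits in place of fake_logits.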