nielsr HF staff commited on
Commit
c3ee9a5
·
1 Parent(s): ac846af

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import cv2
4
+ from ditod import add_vit_config
5
+ from detectron2.config import get_cfg
6
+ from detectron2.utils.visualizer import ColorMode, Visualizer
7
+ from detectron2.data import MetadataCatalog
8
+ from detectron2.engine import DefaultPredictor
9
+
10
+
11
+ # Step 1: instantiate config
12
+ cfg = get_cfg()
13
+ add_vit_config(cfg)
14
+ cfg.merge_from_file("cascade_dit_base.yaml")
15
+
16
+ # Step 2: add model weights URL to config
17
+ cfg.MODEL.WEIGHTS = https://layoutlm.blob.core.windows.net/dit/dit-fts/publaynet_dit-b_mrcnn.pth
18
+
19
+ # Step 3: set device
20
+ # TODO also support GPU
21
+ cfg.MODEL.DEVICE='cpu'
22
+
23
+ # Step 4: define model
24
+ predictor = DefaultPredictor(cfg)
25
+
26
+
27
+ def analyze_image(img):
28
+ md = MetadataCatalog.get(cfg.DATASETS.TEST[0])
29
+ if cfg.DATASETS.TEST[0]=='icdar2019_test':
30
+ md.set(thing_classes=["table"])
31
+ else:
32
+ md.set(thing_classes=["text","title","list","table","figure"])
33
+
34
+ output = predictor(img)["instances"]
35
+ v = Visualizer(img[:, :, ::-1],
36
+ md,
37
+ scale=1.0,
38
+ instance_mode=ColorMode.SEGMENTATION)
39
+ result = v.draw_instance_predictions(output.to("cpu"))
40
+ result_image = result.get_image()[:, :, ::-1]
41
+
42
+ return result_image
43
+
44
+ title = "Interactive demo: Document Layout Analysis with DiT"
45
+ description = "This is a demo for Microsoft's Document Image Transformer (DiT)."
46
+ examples =[['document.png']]
47
+
48
+ iface = gr.Interface(fn=analyze_image,
49
+ inputs=gr.inputs.Image(type="numpy"),
50
+ outputs=gr.outputs.Image(type="numpy", label="analyzed image"),
51
+ title=title,
52
+ description=description,
53
+ article=article,
54
+ examples=examples,
55
+ enable_queue=True)
56
+ iface.launch(debug=True)