wybxc commited on
Commit
2a54612
·
verified ·
1 Parent(s): 8ab23ff

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +131 -0
inference.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import onnx
3
+ import onnxruntime as ort
4
+ import cv2
5
+ from huggingface_hub import hf_hub_download
6
+ import numpy as np
7
+
8
+ # Download the model from the Hugging Face Hub
9
+ model = hf_hub_download(
10
+ repo_id="wybxc/DocLayout-YOLO-DocStructBench-onnx",
11
+ filename="doclayout_yolo_docstructbench_imgsz1024.onnx",
12
+ )
13
+ model = onnx.load(model)
14
+ metadata = {prop.key: prop.value for prop in model.metadata_props}
15
+
16
+ names = ast.literal_eval(metadata["names"])
17
+ stride = ast.literal_eval(metadata["stride"])
18
+
19
+ # Load the model with ONNX Runtime
20
+ session = ort.InferenceSession(model.SerializeToString())
21
+
22
+
23
+ def resize_and_pad_image(image, new_shape, stride=32):
24
+ """
25
+ Resize and pad the image to the specified size, ensuring dimensions are multiples of stride.
26
+
27
+ Parameters:
28
+ - image: Input image
29
+ - new_shape: Target size (integer or (height, width) tuple)
30
+ - stride: Padding alignment stride, default 32
31
+
32
+ Returns:
33
+ - Processed image
34
+ """
35
+ if isinstance(new_shape, int):
36
+ new_shape = (new_shape, new_shape)
37
+
38
+ h, w = image.shape[:2]
39
+ new_h, new_w = new_shape
40
+
41
+ # Calculate scaling ratio
42
+ r = min(new_h / h, new_w / w)
43
+ resized_h, resized_w = int(round(h * r)), int(round(w * r))
44
+
45
+ # Resize image
46
+ image = cv2.resize(image, (resized_w, resized_h), interpolation=cv2.INTER_LINEAR)
47
+
48
+ # Calculate padding size and align to stride multiple
49
+ pad_w = (new_w - resized_w) % stride
50
+ pad_h = (new_h - resized_h) % stride
51
+ top, bottom = pad_h // 2, pad_h - pad_h // 2
52
+ left, right = pad_w // 2, pad_w - pad_w // 2
53
+
54
+ # Add padding
55
+ image = cv2.copyMakeBorder(
56
+ image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
57
+ )
58
+
59
+ return image
60
+
61
+
62
+ class YoloResult:
63
+ def __init__(self, boxes, names):
64
+ self.boxes = [YoloBox(data=d) for d in boxes]
65
+ self.names = names
66
+
67
+
68
+ class YoloBox:
69
+ def __init__(self, data):
70
+ self.xyxy = data[:4]
71
+ self.conf = data[-2]
72
+ self.cls = data[-1]
73
+
74
+
75
+ def inference(image):
76
+ """
77
+ Run inference on the input image.
78
+
79
+ Parameters:
80
+ - image: Input image, HWC format and RGB order
81
+
82
+ Returns:
83
+ - YoloResult object containing the predicted boxes and class names
84
+ """
85
+
86
+ # Preprocess image
87
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
88
+ pix = resize_and_pad_image(image, new_shape=int(image.shape[0] / stride) * stride)
89
+ pix = np.transpose(pix, (2, 0, 1)) # CHW
90
+ pix = np.expand_dims(pix, axis=0) # BCHW
91
+ pix = pix.astype(np.float32) / 255.0 # Normalize to [0, 1]
92
+
93
+ # Run inference
94
+ preds = session.run(None, {"images": pix})[0]
95
+
96
+ # Postprocess predictions
97
+ preds = preds[preds[..., 4] > 0.25]
98
+ return YoloResult(boxes=preds, names=names)
99
+
100
+
101
+ if __name__ == "__main__":
102
+ import sys
103
+ import matplotlib.pyplot as plt
104
+
105
+ image = sys.argv[1]
106
+ image = cv2.imread(image)
107
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
108
+
109
+ layout = inference(image)
110
+
111
+ bitmap = np.ones(image.shape[:2], dtype=np.uint8)
112
+ h, w = bitmap.shape
113
+ vcls = ["abandon", "figure", "table", "isolate_formula", "formula_caption"]
114
+ for i, d in enumerate(layout.boxes):
115
+ x0, y0, x1, y1 = d.xyxy.squeeze()
116
+ x0, y0, x1, y1 = (
117
+ np.clip(int(x0 - 1), 0, w - 1),
118
+ np.clip(int(h - y1 - 1), 0, h - 1),
119
+ np.clip(int(x1 + 1), 0, w - 1),
120
+ np.clip(int(h - y0 + 1), 0, h - 1),
121
+ )
122
+ if layout.names[int(d.cls)] in vcls:
123
+ bitmap[y0:y1, x0:x1] = 0
124
+ else:
125
+ bitmap[y0:y1, x0:x1] = i + 2
126
+ bitmap = bitmap[::-1, :]
127
+
128
+ fig, ax = plt.subplots(1, 2, figsize=(10, 6))
129
+ ax[0].imshow(image)
130
+ ax[1].imshow(bitmap)
131
+ plt.show()