AAAAAAyq commited on
Commit
c987532
1 Parent(s): b3d5599

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -83
app.py CHANGED
@@ -1,83 +1,171 @@
1
- from ultralytics import YOLO
2
- import numpy as np
3
- import matplotlib.pyplot as plt
4
- import gradio as gr
5
- import torch
6
-
7
- model = YOLO('checkpoints/FastSAM.pt') # load a custom model
8
-
9
- def format_results(result,filter = 0):
10
- annotations = []
11
- n = len(result.masks.data)
12
- for i in range(n):
13
- annotation = {}
14
- mask = result.masks.data[i] == 1.0
15
-
16
- if torch.sum(mask) < filter:
17
- continue
18
- annotation['id'] = i
19
- annotation['segmentation'] = mask.cpu().numpy()
20
- annotation['bbox'] = result.boxes.data[i]
21
- annotation['score'] = result.boxes.conf[i]
22
- annotation['area'] = annotation['segmentation'].sum()
23
- annotations.append(annotation)
24
- return annotations
25
-
26
- def show_mask(annotation, ax, random_color=True, bbox=None, points=None):
27
- if random_color : # random mask color
28
- color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
29
- else:
30
- color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
31
- if type(annotation) == dict:
32
- annotation = annotation['segmentation']
33
- mask = annotation
34
- h, w = mask.shape[-2:]
35
- mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
36
- # draw box
37
- if bbox is not None:
38
- x1, y1, x2, y2 = bbox
39
- ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
40
- # draw point
41
- if points is not None:
42
- ax.scatter([point[0] for point in points], [point[1] for point in points], s=10, c='g')
43
- ax.imshow(mask_image)
44
- return mask_image
45
-
46
- def post_process(annotations, image, mask_random_color=True, bbox=None, points=None):
47
- fig = plt.figure(figsize=(10, 10))
48
- plt.imshow(image)
49
- for i, mask in enumerate(annotations):
50
- show_mask(mask, plt.gca(),random_color=mask_random_color,bbox=bbox,points=points)
51
- plt.axis('off')
52
-
53
- plt.tight_layout()
54
- return fig
55
-
56
-
57
- # post_process(results[0].masks, Image.open("../data/cake.png"))
58
-
59
- def predict(input, input_size):
60
- input_size = int(input_size) # 确保 imgsz 是整数
61
- results = model(input, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
62
- results = format_results(results[0], 100)
63
- results.sort(key=lambda x: x['area'], reverse=True)
64
- pil_image = post_process(annotations=results, image=input)
65
- return pil_image
66
-
67
- # inp = 'assets/sa_192.jpg'
68
- # results = model(inp, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=1024)
69
- # results = format_results(results[0], 100)
70
- # post_process(annotations=results, image_path=inp)
71
-
72
- demo = gr.Interface(fn=predict,
73
- inputs=[gr.inputs.Image(type='pil'), gr.inputs.Dropdown(choices=[512, 800, 1024], default=1024)],
74
- outputs=['plot'],
75
- examples=[["assets/sa_8776.jpg", 1024]],
76
- # ["assets/sa_1309.jpg", 1024]],
77
- # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
78
- # ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
79
- # ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
80
- # ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"],],
81
- )
82
-
83
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ultralytics import YOLO
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import gradio as gr
5
+ import cv2
6
+ import torch
7
+
8
+ model = YOLO('checkpoints/FastSAM.pt') # load a custom model
9
+
10
+
11
+ def fast_process(annotations, image):
12
+ fig = plt.figure(figsize=(10, 10))
13
+ plt.imshow(image)
14
+ #original_h = image.shape[0]
15
+ #original_w = image.shape[1]
16
+ #for i, mask in enumerate(annotations):
17
+ # mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
18
+ # annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
19
+ fast_show_mask(annotations,
20
+ plt.gca())
21
+ #target_height=original_h,
22
+ #target_width=original_w)
23
+ plt.axis('off')
24
+ plt.tight_layout()
25
+ return fig
26
+
27
+
28
+ # CPU post process
29
+ def fast_show_mask(annotation, ax):
30
+ msak_sum = annotation.shape[0]
31
+ height = annotation.shape[1]
32
+ weight = annotation.shape[2]
33
+ # 将annotation 按照面积 排序
34
+ areas = np.sum(annotation, axis=(1, 2))
35
+ sorted_indices = np.argsort(areas)[::1]
36
+ annotation = annotation[sorted_indices]
37
+
38
+ index = (annotation != 0).argmax(axis=0)
39
+ color = np.random.random((msak_sum, 1, 1, 3))
40
+ transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
41
+ visual = np.concatenate([color, transparency], axis=-1)
42
+ mask_image = np.expand_dims(annotation, -1) * visual
43
+
44
+ show = np.zeros((height, weight, 4))
45
+
46
+ h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
47
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
48
+ # 使用向量化索引更新show的值
49
+ show[h_indices, w_indices, :] = mask_image[indices]
50
+
51
+
52
+ #if retinamask == False:
53
+ # show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
54
+ ax.imshow(show)
55
+
56
+
57
+
58
+ # post_process(results[0].masks, Image.open("../data/cake.png"))
59
+
60
+ def predict(input, input_size):
61
+ input_size = int(input_size) # 确保 imgsz 是整数
62
+ results = model(input, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
63
+ pil_image = fast_process(annotations=results[0].masks.data, image=input)
64
+
65
+ return pil_image
66
+
67
+
68
+ # inp = 'assets/sa_192.jpg'
69
+ # results = model(inp, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=1024)
70
+ # results = format_results(results[0], 100)
71
+ # post_process(annotations=results, image_path=inp)
72
+
73
+ demo = gr.Interface(fn=predict,
74
+ inputs=[gr.inputs.Image(type='pil'), gr.inputs.Dropdown(choices=[512, 800, 1024], default=1024)],
75
+ outputs=['plot'],
76
+ examples=[["assets/sa_8776.jpg", 1024]],
77
+ # ["assets/sa_1309.jpg", 1024]],
78
+ # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
79
+ # ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
80
+ # ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
81
+ # ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"],],
82
+ )
83
+
84
+ demo.launch()
85
+ """
86
+
87
+ from ultralytics import YOLO
88
+ import numpy as np
89
+ import matplotlib.pyplot as plt
90
+ import gradio as gr
91
+ import torch
92
+
93
+ model = YOLO('checkpoints/FastSAM.pt') # load a custom model
94
+
95
+ def format_results(result,filter = 0):
96
+ annotations = []
97
+ n = len(result.masks.data)
98
+ for i in range(n):
99
+ annotation = {}
100
+ mask = result.masks.data[i] == 1.0
101
+
102
+ if torch.sum(mask) < filter:
103
+ continue
104
+ annotation['id'] = i
105
+ annotation['segmentation'] = mask.cpu().numpy()
106
+ annotation['bbox'] = result.boxes.data[i]
107
+ annotation['score'] = result.boxes.conf[i]
108
+ annotation['area'] = annotation['segmentation'].sum()
109
+ annotations.append(annotation)
110
+ return annotations
111
+
112
+ def show_mask(annotation, ax, random_color=True, bbox=None, points=None):
113
+ if random_color : # random mask color
114
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
115
+ else:
116
+ color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
117
+ if type(annotation) == dict:
118
+ annotation = annotation['segmentation']
119
+ mask = annotation
120
+ h, w = mask.shape[-2:]
121
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
122
+ # draw box
123
+ if bbox is not None:
124
+ x1, y1, x2, y2 = bbox
125
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
126
+ # draw point
127
+ if points is not None:
128
+ ax.scatter([point[0] for point in points], [point[1] for point in points], s=10, c='g')
129
+ ax.imshow(mask_image)
130
+ return mask_image
131
+
132
+ def post_process(annotations, image, mask_random_color=True, bbox=None, points=None):
133
+ fig = plt.figure(figsize=(10, 10))
134
+ plt.imshow(image)
135
+ for i, mask in enumerate(annotations):
136
+ show_mask(mask, plt.gca(),random_color=mask_random_color,bbox=bbox,points=points)
137
+ plt.axis('off')
138
+
139
+ plt.tight_layout()
140
+ return fig
141
+
142
+
143
+ # post_process(results[0].masks, Image.open("../data/cake.png"))
144
+
145
+ def predict(input, input_size):
146
+ input_size = int(input_size) # 确保 imgsz 是整数
147
+ results = model(input, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
148
+ results = format_results(results[0], 100)
149
+ results.sort(key=lambda x: x['area'], reverse=True)
150
+ pil_image = post_process(annotations=results, image=input)
151
+ return pil_image
152
+
153
+ # inp = 'assets/sa_192.jpg'
154
+ # results = model(inp, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=1024)
155
+ # results = format_results(results[0], 100)
156
+ # post_process(annotations=results, image_path=inp)
157
+
158
+ demo = gr.Interface(fn=predict,
159
+ inputs=[gr.inputs.Image(type='pil'), gr.inputs.Dropdown(choices=[512, 800, 1024], default=1024)],
160
+ outputs=['plot'],
161
+ examples=[["assets/sa_8776.jpg", 1024]],
162
+ # ["assets/sa_1309.jpg", 1024]],
163
+ # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
164
+ # ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
165
+ # ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
166
+ # ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"],],
167
+ )
168
+
169
+ demo.launch()
170
+
171
+ """