AAAAAAyq commited on
Commit
87c6f54
·
1 Parent(s): 4d26566

Update application file

Browse files
Files changed (1) hide show
  1. app.py +32 -9
app.py CHANGED
@@ -4,11 +4,30 @@ import numpy as np
4
  import matplotlib.pyplot as plt
5
  import gradio as gr
6
  import io
 
7
  # import cv2
8
 
9
  model = YOLO('checkpoints/FastSAM.pt') # load a custom model
10
 
11
- def show_mask(annotation, ax, random_color=False, bbox=None, points=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  if random_color : # random mask color
13
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
14
  else:
@@ -28,28 +47,27 @@ def show_mask(annotation, ax, random_color=False, bbox=None, points=None):
28
  ax.imshow(mask_image)
29
  return mask_image
30
 
31
- def post_process(annotations, image, mask_random_color=False, bbox=None, points=None):
32
- # image = cv2.imread(image_path)
33
- # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
34
  plt.figure(figsize=(10, 10))
35
  plt.imshow(image)
36
  for i, mask in enumerate(annotations):
37
  show_mask(mask, plt.gca(),random_color=mask_random_color,bbox=bbox,points=points)
38
  plt.axis('off')
39
-
40
  # create a BytesIO object
41
  buf = io.BytesIO()
42
 
43
  # save plot to buf
44
  plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.0)
45
- # plt.savefig('buffer/tmp.png', bbox_inches='tight', pad_inches=0.0)
46
 
47
  # use PIL to open the image
48
  img = Image.open(buf)
49
 
 
 
 
50
  # don't forget to close the buffer
51
  buf.close()
52
- return img
53
 
54
 
55
  # def show_mask(annotation, ax, random_color=False):
@@ -77,10 +95,15 @@ def post_process(annotations, image, mask_random_color=False, bbox=None, points=
77
  # post_process(results[0].masks, Image.open("../data/cake.png"))
78
 
79
  def predict(inp):
80
- results = model(inp, device='0', retina_masks=True, iou=0.7, conf=0.25, imgsz=1024)
81
- pil_image = post_process(results[0].masks, inp)
 
82
  return pil_image
83
 
 
 
 
 
84
 
85
  demo = gr.Interface(fn=predict,
86
  inputs=gr.inputs.Image(type='pil'),
 
4
  import matplotlib.pyplot as plt
5
  import gradio as gr
6
  import io
7
+ import torch
8
  # import cv2
9
 
10
  model = YOLO('checkpoints/FastSAM.pt') # load a custom model
11
 
12
+ def format_results(result,filter = 0):
13
+ annotations = []
14
+ n = len(result.masks.data)
15
+ for i in range(n):
16
+ annotation = {}
17
+ mask = result.masks.data[i] == 1.0
18
+
19
+
20
+ if torch.sum(mask) < filter:
21
+ continue
22
+ annotation['id'] = i
23
+ annotation['segmentation'] = mask.cpu().numpy()
24
+ annotation['bbox'] = result.boxes.data[i]
25
+ annotation['score'] = result.boxes.conf[i]
26
+ annotation['area'] = annotation['segmentation'].sum()
27
+ annotations.append(annotation)
28
+ return annotations
29
+
30
+ def show_mask(annotation, ax, random_color=True, bbox=None, points=None):
31
  if random_color : # random mask color
32
  color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
33
  else:
 
47
  ax.imshow(mask_image)
48
  return mask_image
49
 
50
+ def post_process(annotations, image, mask_random_color=True, bbox=None, points=None):
 
 
51
  plt.figure(figsize=(10, 10))
52
  plt.imshow(image)
53
  for i, mask in enumerate(annotations):
54
  show_mask(mask, plt.gca(),random_color=mask_random_color,bbox=bbox,points=points)
55
  plt.axis('off')
 
56
  # create a BytesIO object
57
  buf = io.BytesIO()
58
 
59
  # save plot to buf
60
  plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.0)
 
61
 
62
  # use PIL to open the image
63
  img = Image.open(buf)
64
 
65
+ # copy the image data
66
+ img_copy = img.copy()
67
+
68
  # don't forget to close the buffer
69
  buf.close()
70
+ return img_copy
71
 
72
 
73
  # def show_mask(annotation, ax, random_color=False):
 
95
  # post_process(results[0].masks, Image.open("../data/cake.png"))
96
 
97
  def predict(inp):
98
+ results = model(inp, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=1024)
99
+ results = format_results(results[0], 100)
100
+ pil_image = post_process(annotations=results, image=inp)
101
  return pil_image
102
 
103
+ # inp = 'assets/sa_192.jpg'
104
+ # results = model(inp, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=1024)
105
+ # results = format_results(results[0], 100)
106
+ # post_process(annotations=results, image_path=inp)
107
 
108
  demo = gr.Interface(fn=predict,
109
  inputs=gr.inputs.Image(type='pil'),