taher30 commited on
Commit
c24777f
·
verified ·
1 Parent(s): 8e4f839

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -7
app.py CHANGED
@@ -6,34 +6,51 @@ from torchvision import transforms
6
  from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
7
  import matplotlib.pyplot as plt
8
  import gradio as gr
 
9
  # import segmentation_models_pytorch as smp
10
 
 
 
11
 
12
 
13
  # image= cv2.imread('image_4.png', cv2.IMREAD_COLOR)
14
- def get_masks(model_type, image):
15
- if model_type.all() == 'vit_h':
 
 
 
16
  sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
17
-
18
- if model_type,all() == 'vit_b':
19
  sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
20
 
21
- if model_type.all() == 'vit_l':
22
  sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
 
 
23
 
 
 
 
 
24
  mask_generator = SamAutomaticMaskGenerator(sam)
25
  masks = mask_generator.generate(image)
 
 
 
26
  for i, mask_data in enumerate(masks):
27
  mask = mask_data['segmentation']
28
  color = colors[i]
29
  composite_image[mask] = (color[:3] * 255).astype(np.uint8) # Apply color to mask
 
30
 
31
  # Combine original image with the composite mask image
32
- overlayed_image = (composite_image * 0.5 + image_cv.squeeze().permute(1, 2, 0).cpu().numpy() * 0.5).astype(np.uint8)
 
33
  return overlayed_image
34
 
35
 
36
 
 
37
  iface = gr.Interface(
38
  fn=get_masks,
39
  inputs=["image", gr.components.Dropdown(choices=['vit_h', 'vit_b', 'vit_l'], label="Model Type")],
@@ -42,5 +59,4 @@ iface = gr.Interface(
42
  description="Upload an image, select a model type, and receive the segmented and classified parts."
43
  )
44
 
45
-
46
  iface.launch()
 
6
  from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
7
  import matplotlib.pyplot as plt
8
  import gradio as gr
9
+
10
  # import segmentation_models_pytorch as smp
11
 
12
+ ##set the device to cuda for sam model
13
+ # device = torch.device('cuda')
14
 
15
 
16
  # image= cv2.imread('image_4.png', cv2.IMREAD_COLOR)
17
+ def get_masks( image, model_type):
18
+ print(image)
19
+ # image_pil = Image.fromarray(image.astype('uint8'), 'RGB')
20
+ # print(image_pil)
21
+ if model_type == 'vit_h':
22
  sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
23
+ if model_type == 'vit_b':
 
24
  sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
25
 
26
+ if model_type == 'vit_l':
27
  sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
28
+ else:
29
+ sam= sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
30
 
31
+ # print(image.shape)
32
+ #set the device to cuda for sam model
33
+ # sam.to(device= device)
34
+
35
  mask_generator = SamAutomaticMaskGenerator(sam)
36
  masks = mask_generator.generate(image)
37
+ composite_image = np.zeros_like(image)
38
+ colors = plt.cm.jet(np.linspace(0, 1, len(masks))) # Generate distinct colors
39
+
40
  for i, mask_data in enumerate(masks):
41
  mask = mask_data['segmentation']
42
  color = colors[i]
43
  composite_image[mask] = (color[:3] * 255).astype(np.uint8) # Apply color to mask
44
+ print(composite_image.shape, image.shape)
45
 
46
  # Combine original image with the composite mask image
47
+ overlayed_image = (composite_image * 0.5 + torch.from_numpy(image).resize(738, 1200, 3).cpu().numpy() * 0.5).astype(np.uint8)
48
+
49
  return overlayed_image
50
 
51
 
52
 
53
+
54
  iface = gr.Interface(
55
  fn=get_masks,
56
  inputs=["image", gr.components.Dropdown(choices=['vit_h', 'vit_b', 'vit_l'], label="Model Type")],
 
59
  description="Upload an image, select a model type, and receive the segmented and classified parts."
60
  )
61
 
 
62
  iface.launch()