File size: 3,397 Bytes
8fa1f84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gc

import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
import gradio as gr
import cv2
from demo.mask_utils import *

class SAM_Inference:
    """Point- and box-prompted segmentation with a Segment Anything model.

    Loads a SAM checkpoint once and exposes Gradio-facing helpers that take
    an image plus a user prompt (a clicked point or a drawn box) and return
    visualization images together with the raw binary mask.
    """

    def __init__(self, model_type='vit_b', device='cuda') -> None:
        """Load the requested SAM backbone and build predictor objects.

        Args:
            model_type: One of 'vit_b', 'vit_l', 'vit_h' (checkpoint must
                exist under ./checkpoints/).
            device: Torch device string the model is moved to.
        """
        # Checkpoint paths for the three released SAM backbones.
        models = {
            'vit_b': './checkpoints/sam_vit_b_01ec64.pth',
            'vit_l': './checkpoints/sam_vit_l_0b3195.pth',
            'vit_h': './checkpoints/sam_vit_h_4b8939.pth'
        }

        sam = sam_model_registry[model_type](checkpoint=models[model_type])
        sam = sam.to(device)

        self.predictor = SamPredictor(sam)
        self.mask_generator = SamAutomaticMaskGenerator(model=sam)

    def img_select_point(self, original_img: np.ndarray, evt: gr.SelectData):
        """Segment the object under a clicked point.

        Args:
            original_img: HxWx3 image array from the Gradio image component.
            evt: Gradio select event; ``evt.index`` is the (x, y) click.

        Returns:
            Tuple of (image with click marker drawn, mask visualization,
            image blended with the colored mask, raw binary mask).
        """
        img = original_img.copy()
        # A single foreground point (label 1) at the click location.
        sel_pix = [(evt.index, 1)]

        masks = self.run_inference(original_img, sel_pix)
        for point, label in sel_pix:
            # cv2.circle requires the center as a tuple of ints; evt.index
            # arrives as a list and may contain floats.
            center = tuple(int(c) for c in point)
            cv2.circle(img, center, 5, (240, 240, 240), -1, 0)
            cv2.circle(img, center, 5, (30, 144, 255), 2, 0)

        mask = masks[0][0]
        colored_mask = mask_foreground(mask)
        res = img_add_masks(original_img, colored_mask, mask)
        return img, process_mask_to_show(mask), res, mask

    def gen_box_seg(self, inp):
        """Segment the object inside the most recently drawn box.

        Args:
            inp: Gradio sketch-tool payload with 'image' and 'boxes' keys.

        Returns:
            Tuple of (mask visualization, image blended with the colored
            mask, raw binary mask).

        Raises:
            gr.Error: If no image was uploaded or no box was drawn.
        """
        if inp is None:
            raise gr.Error("Please upload an image first!")
        image = inp['image']
        if len(inp['boxes']) == 0:
            raise gr.Error("Please clear the raw boxes and draw a box first!")
        # Only the last drawn box is used as the prompt.
        boxes = inp['boxes'][-1]

        input_box = np.array([boxes[0], boxes[1], boxes[2], boxes[3]]).astype(int)

        masks = self.predict_box(image, input_box)

        mask = masks[0][0]
        colored_mask = mask_foreground(mask)
        res = img_add_masks(image, colored_mask, mask)

        return process_mask_to_show(mask), res, mask

    def run_inference(self, input_x, selected_points):
        """Predict masks from point prompts.

        Args:
            input_x: HxWx3 image array.
            selected_points: List of ((x, y), label) pairs; label 1 marks
                foreground, 0 background.

        Returns:
            Numpy array of predicted masks, or [] when no points are given.
        """
        if len(selected_points) == 0:
            return []

        self.predictor.set_image(input_x)

        points = torch.Tensor(
            [p for p, _ in selected_points]
        ).to(self.predictor.device).unsqueeze(0)

        labels = torch.Tensor(
            [int(l) for _, l in selected_points]
        ).to(self.predictor.device).unsqueeze(0)

        # Map pixel coordinates into the model's resized input frame.
        transformed_points = self.predictor.transform.apply_coords_torch(
            points, input_x.shape[:2])

        # Predict segmentation according to the prompt points.
        masks, scores, logits = self.predictor.predict_torch(
            point_coords=transformed_points,
            point_labels=labels,
            multimask_output=False,
        )
        return self._masks_to_numpy(masks)

    def predict_box(self, input_x, input_box):
        """Predict masks from a single box prompt.

        Args:
            input_x: HxWx3 image array.
            input_box: Integer array of (x1, y1, x2, y2) pixel coordinates.

        Returns:
            Numpy array of predicted masks.
        """
        self.predictor.set_image(input_x)

        input_boxes = torch.tensor(input_box[None, :], device=self.predictor.device)
        # Map pixel coordinates into the model's resized input frame.
        transformed_boxes = self.predictor.transform.apply_boxes_torch(input_boxes, input_x.shape[:2])

        masks, _, _ = self.predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False
        )
        return self._masks_to_numpy(masks)

    def _masks_to_numpy(self, masks: torch.Tensor) -> np.ndarray:
        """Move predicted masks to CPU and release GPU memory held by the pass."""
        result = masks.cpu().detach().numpy()
        gc.collect()
        # empty_cache() is only meaningful (and only safe on CPU-only
        # builds) when CUDA is actually available.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return result