|
import torch |
|
from torchvision.transforms.functional import to_pil_image |
|
from segment_anything import SamPredictor, sam_model_registry |
|
from PIL import Image |
|
|
|
class SegmentationModel:
    """Abstract interface for single-image segmentation models.

    Subclasses implement :meth:`generate`, mapping an image tensor to a
    PIL mask image.
    """

    def __init__(self) -> None:
        pass

    def generate(self, image: torch.Tensor) -> Image.Image:
        """Return a segmentation mask for *image* as a PIL image.

        Args:
            image: a (C, H, W) image tensor.

        Raises:
            NotImplementedError: always — subclasses must override.
        """
        # Raising (rather than the old silent `pass` -> None) surfaces
        # accidental use of the abstract base at the call site.
        raise NotImplementedError
|
|
|
class SamSegmentationModel(SegmentationModel):
    """Segmentation model backed by Meta's Segment Anything (SAM).

    Prompts SAM with a single positive point at the image center and
    returns the best-scoring mask, inverted, as an RGB PIL image.
    """

    def __init__(
        self,
        model_type: str,
        checkpoint_path: str,
        device = torch.device("cpu"),
    ) -> None:
        """Load a SAM checkpoint and wrap it in a ``SamPredictor``.

        Args:
            model_type: key into ``sam_model_registry`` (e.g. "vit_b").
            checkpoint_path: path to the SAM weights file.
            device: torch device to run inference on.
        """
        super().__init__()
        sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
        sam.to(device)
        self.device = device
        self.model = SamPredictor(sam)

    def generate(self, image: torch.Tensor) -> Image.Image:
        """Segment *image* using a single point prompt at its center.

        Args:
            image: a (C, H, W) image tensor. NOTE(review): SAM's
                preprocessing assumes RGB values in 0-255 — confirm
                against the caller.

        Returns:
            An RGB PIL image of the inverted best-scoring mask
            (segmented region dark, background light).
        """
        _, H, W = image.size()
        batched = image.unsqueeze(0).to(self.device)
        # BUG FIX: SamPredictor.set_torch_image expects an image already
        # resized by SAM's ResizeLongestSide transform; passing the raw
        # tensor trips the predictor's input-size assertion.
        transformed = self.model.transform.apply_image_torch(batched)
        self.model.set_torch_image(transformed, original_image_size=(H, W))
        # BUG FIX: SAM point prompts are (X, Y) pixel coordinates, so the
        # image center is (W/2, H/2), not (H/2, W/2). The coordinates
        # must also be mapped into the resized frame.
        center = torch.tensor(
            [[[W / 2, H / 2]]], dtype=torch.float, device=self.device
        )
        center = self.model.transform.apply_coords_torch(center, (H, W))
        labels = torch.tensor([[1]], device=self.device)
        masks, scores, _ = self.model.predict_torch(
            point_coords=center,
            point_labels=labels,
            boxes=None,
            multimask_output=True,
        )
        # Drop the batch dimension: (1, 3, H, W) -> (3, H, W).
        masks = masks.squeeze(0)
        scores = scores.squeeze(0)
        best = masks[torch.argmax(scores).item()]
        # Invert the boolean mask so the segmented region is dark, then
        # replicate it to three channels for an RGB PIL image.
        inverted = 1.0 - best.float()
        return to_pil_image(torch.stack([inverted, inverted, inverted]))