# Ultralytics YOLO 🚀, AGPL-3.0 license
"""

SAM model interface.



This module provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for real-time image

segmentation tasks. The SAM model allows for promptable segmentation with unparalleled versatility in image analysis,

and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new

image distributions and tasks without prior knowledge.



Key Features:

    - Promptable segmentation

    - Real-time performance

    - Zero-shot transfer capabilities

    - Trained on SA-1B dataset

"""

from pathlib import Path

from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info

from .build import build_sam
from .predict import Predictor, SAM2Predictor


class SAM(Model):
    """
    SAM (Segment Anything Model) interface class for real-time image segmentation tasks.

    This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for
    promptable segmentation with versatility in image analysis. It supports various prompts such as bounding
    boxes, points, or labels, and features zero-shot performance capabilities.

    Attributes:
        model (torch.nn.Module): The loaded SAM model.
        is_sam2 (bool): Indicates whether the model is a SAM2 variant.
        task (str): The task type, set to "segment" for SAM models.

    Methods:
        predict: Performs segmentation prediction on the given image or video source.
        info: Logs information about the SAM model.

    Examples:
        >>> sam = SAM("sam_b.pt")
        >>> results = sam.predict("image.jpg", points=[[500, 375]])
        >>> for r in results:
        ...     print(f"Detected {len(r.masks)} masks")
    """

    def __init__(self, model="sam_b.pt") -> None:
        """
        Initializes the SAM (Segment Anything Model) instance.

        Args:
            model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.

        Raises:
            NotImplementedError: If the model file extension is not .pt or .pth.

        Examples:
            >>> sam = SAM("sam_b.pt")
            >>> print(sam.is_sam2)
        """
        if model and Path(model).suffix not in {".pt", ".pth"}:
            raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
        # SAM2 checkpoints are detected by filename stem so task_map can select the matching predictor class.
        self.is_sam2 = "sam2" in Path(model).stem
        super().__init__(model=model, task="segment")

    def _load(self, weights: str, task=None):
        """
        Loads the specified weights into the SAM model.

        This method initializes the SAM model with the provided weights file, setting up the model architecture
        and loading the pre-trained parameters.

        Args:
            weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
            task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.

        Examples:
            >>> sam = SAM("sam_b.pt")
            >>> sam._load("path/to/custom_weights.pt")
        """
        self.model = build_sam(weights)

    def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
        """
        Performs segmentation prediction on the given image or video source.

        Args:
            source (str | PIL.Image | numpy.ndarray): Path to the image or video file, or a PIL.Image object, or
                a numpy.ndarray object.
            stream (bool): If True, enables real-time streaming.
            bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
            points (List[List[float]] | None): List of points for prompted segmentation.
            labels (List[int] | None): List of labels for prompted segmentation.
            **kwargs (Any): Additional keyword arguments for prediction.

        Returns:
            (List): The model predictions.

        Examples:
            >>> sam = SAM("sam_b.pt")
            >>> results = sam.predict("image.jpg", points=[[500, 375]])
            >>> for r in results:
            ...     print(f"Detected {len(r.masks)} masks")
        """
        # SAM defaults (1024 imgsz, 0.25 conf) are applied first so explicit caller kwargs take precedence.
        overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
        kwargs = {**overrides, **kwargs}
        prompts = dict(bboxes=bboxes, points=points, labels=labels)
        return super().predict(source, stream, prompts=prompts, **kwargs)

    def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
        """
        Performs segmentation prediction on the given image or video source.

        This method is an alias for the 'predict' method, providing a convenient way to call the SAM model
        for segmentation tasks.

        Args:
            source (str | PIL.Image | numpy.ndarray | None): Path to the image or video file, or a PIL.Image
                object, or a numpy.ndarray object.
            stream (bool): If True, enables real-time streaming.
            bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
            points (List[List[float]] | None): List of points for prompted segmentation.
            labels (List[int] | None): List of labels for prompted segmentation.
            **kwargs (Any): Additional keyword arguments to be passed to the predict method.

        Returns:
            (List): The model predictions, typically containing segmentation masks and other relevant information.

        Examples:
            >>> sam = SAM("sam_b.pt")
            >>> results = sam("image.jpg", points=[[500, 375]])
            >>> print(f"Detected {len(results[0].masks)} masks")
        """
        return self.predict(source, stream, bboxes, points, labels, **kwargs)

    def info(self, detailed=False, verbose=True):
        """
        Logs information about the SAM model.

        This method provides details about the Segment Anything Model (SAM), including its architecture,
        parameters, and computational requirements.

        Args:
            detailed (bool): If True, displays detailed information about the model layers and operations.
            verbose (bool): If True, prints the information to the console.

        Returns:
            (Tuple): A tuple containing the model's information (string representations of the model).

        Examples:
            >>> sam = SAM("sam_b.pt")
            >>> info = sam.info()
            >>> print(info[0])  # Print summary information
        """
        return model_info(self.model, detailed=detailed, verbose=verbose)

    @property
    def task_map(self):
        """
        Provides a mapping from the 'segment' task to its corresponding 'Predictor'.

        Returns:
            (Dict[str, Dict[str, Type[Predictor]]]): A dictionary mapping the 'segment' task to a nested
                dictionary whose 'predictor' key holds the Predictor class. For SAM2 models it holds
                SAM2Predictor, otherwise the standard Predictor.

        Examples:
            >>> sam = SAM("sam_b.pt")
            >>> task_map = sam.task_map
            >>> print(task_map)
            {'segment': {'predictor': <class 'ultralytics.models.sam.predict.Predictor'>}}
        """
        return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}