File size: 3,828 Bytes
ab854b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Ultralytics YOLO 🚀, AGPL-3.0 license

from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ops
import torch

def print_tensor_shapes(item, index_path=None):
    """
    Walk a nested list/tuple structure and print the shape of every tensor found in it.

    Anything that is neither a tensor nor a list/tuple is reported with its type instead.

    Args:
        item: A torch.Tensor, a (possibly nested) list or tuple, or any other object.
        index_path (list | None): Indices leading from the top-level object down to `item`.
            None means this is the top-level call and the path starts empty.
    """
    path = [] if index_path is None else index_path

    # Leaf case 1: a tensor, report where it sits and its shape.
    if isinstance(item, torch.Tensor):
        print(f"Tensor at Index Path {path}: Shape {item.shape}")
        return

    # Leaf case 2: an object that cannot be descended into.
    if not isinstance(item, (list, tuple)):
        print(f"Item at Index Path {path} is not a tensor, list, or tuple. It is a {type(item)}.")
        return

    # Container case: descend into each element, extending the path by its position.
    for position, child in enumerate(item):
        print_tensor_shapes(child, path + [position])


class SegmentationPredictor(DetectionPredictor):
    """
    Predictor for segmentation models, built on top of DetectionPredictor.

    Example:
        ```python
        from ultralytics.utils import ASSETS
        from ultralytics.models.yolo.segment import SegmentationPredictor

        args = dict(model='yolov8n-seg.pt', source=ASSETS)
        predictor = SegmentationPredictor(overrides=args)
        predictor.predict_cli()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialise the parent predictor, then force the task to 'segment'."""
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = 'segment'

    def postprocess(self, preds, img, orig_imgs):
        """
        Run NMS on the raw predictions and package each image's boxes, masks and regression outputs.

        Args:
            preds: Model outputs. `preds[0]` is handed to NMS as the prediction; `preds[1]` is indexed for the
                mask coefficients (`[1]`), the prototypes and the regression outputs (`[-1]`).
            img: The preprocessed input batch; `img.shape[2:]` is used as the inference image size
                (presumably a BCHW tensor; confirm against the caller).
            orig_imgs: Either a list with one original image per batch entry, or a single object that is
                reused for every result.

        Returns:
            (list): One Results object per entry of the detections returned by NMS.
        """
        aux_out = preds[1]
        regression_preds = aux_out[-1]
        infer_hw = img.shape[2:]

        # This NMS variant takes the mask/regression extras and returns (detections, regression outputs).
        # NOTE(review): prototypes are always read from index -2 here, while the mask step below picks the
        # prototype by len(preds[1]). Confirm the two agree when preds[1] has 3 elements.
        p, final_reg = ops.non_max_suppression(
            prediction=preds[0],
            mask_coef=aux_out[1],
            proto=aux_out[-2],
            img_shape=infer_hw,
            conf_thres=self.args.conf,
            iou_thres=self.args.iou,
            agnostic=self.args.agnostic_nms,
            max_det=self.args.max_det,
            nc=len(self.model.names),
            regression_var=regression_preds,
            classes=self.args.classes,
        )

        results = []
        is_list = isinstance(orig_imgs, list)  # a list of originals rather than one shared object

        # Choose the prototype masks according to how many items the auxiliary output holds.
        n_aux = len(aux_out)
        if n_aux == 3:
            proto = aux_out[-1]
        elif n_aux == 4:
            proto = aux_out[-2]
        else:
            proto = aux_out

        for idx, det in enumerate(p):
            orig_img = orig_imgs[idx] if is_list else orig_imgs
            img_path = self.batch[0][idx]

            masks = None  # stays None when the image has no detections
            if len(det):
                if self.args.retina_masks:
                    # Native-resolution masks: rescale the boxes first, then build masks at original size.
                    if is_list:
                        det[:, :4] = ops.scale_boxes(infer_hw, det[:, :4], orig_img.shape)
                    masks = ops.process_mask_native(proto[idx], det[:, 6:], det[:, :4], orig_img.shape[:2])
                else:
                    # Build masks at inference size with the unscaled boxes, then rescale the boxes.
                    masks = ops.process_mask(proto[idx], det[:, 6:], det[:, :4], infer_hw, upsample=True)
                    if is_list:
                        det[:, :4] = ops.scale_boxes(infer_hw, det[:, :4], orig_img.shape)

            results.append(
                Results(
                    orig_img,
                    path=img_path,
                    names=self.model.names,
                    boxes=det[:, :6],
                    masks=masks,
                    regression_preds=final_reg[idx],
                )
            )
        return results