|
import sys |
|
import random |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from .point import Point |
|
from .polygon import Polygon |
|
from .scribble import Scribble |
|
from .circle import Circle |
|
|
|
from modeling.utils import configurable |
|
|
|
|
|
class ShapeSampler(nn.Module):
    """Randomly draw shapes from ground-truth instance masks.

    For each selected instance, one shape generator is chosen at random
    (weighted by ``shape_prob``) and its ``draw(mask, box)`` method is called.
    Generators are resolved by name in ``from_config`` from this module's
    namespace (presumably the imported ``Point``/``Polygon``/``Scribble``/
    ``Circle`` classes).
    """

    @configurable
    def __init__(self, max_candidate=1, shape_prob=None, shape_candidate=None, is_train=True):
        """
        Args:
            max_candidate: Maximum number of instances sampled per call when
                ``is_train`` is True.
            shape_prob: Sampling weights, one per entry of ``shape_candidate``.
            shape_candidate: Shape generator objects exposing ``draw(mask, box)``.
            is_train: When True, instances are shuffled and truncated to
                ``max_candidate``; otherwise every instance is used.
        """
        super().__init__()
        self.max_candidate = max_candidate
        # None defaults avoid sharing one mutable list across every instance.
        self.shape_prob = [] if shape_prob is None else shape_prob
        self.shape_candidate = [] if shape_candidate is None else shape_candidate
        self.is_train = is_train

    @classmethod
    def from_config(cls, cfg, is_train=True, mode=None):
        """Build ``__init__`` keyword arguments from ``cfg['STROKE_SAMPLER']``.

        Args:
            cfg: Config mapping with a ``'STROKE_SAMPLER'`` section containing
                ``MAX_CANDIDATE``, ``CANDIDATE_PROBS`` and ``CANDIDATE_NAMES``.
            is_train: Training flag stored on the sampler and passed to each
                shape class constructor (except in ``'hack_train'`` mode).
            mode: ``'hack_train'`` builds every shape class with ``True`` and
                keeps the configured probabilities. Otherwise, when ``is_train``
                is False, it names the single shape class that is always sampled.

        Returns:
            dict of keyword arguments for ``__init__``.

        Raises:
            ValueError: If ``is_train`` is False, ``mode`` is not
                ``'hack_train'``, and ``mode`` is not a configured candidate name.
        """
        sampler_cfg = cfg['STROKE_SAMPLER']
        max_candidate = sampler_cfg['MAX_CANDIDATE']
        candidate_probs = sampler_cfg['CANDIDATE_PROBS']
        candidate_names = sampler_cfg['CANDIDATE_NAMES']
        module = sys.modules[__name__]

        if mode == 'hack_train':
            # Shape classes are built in training mode here regardless of is_train.
            candidate_classes = [getattr(module, name)(cfg, True) for name in candidate_names]
        else:
            if not is_train:
                if mode not in candidate_names:
                    raise ValueError(
                        f"In eval mode, `mode` must be one of {list(candidate_names)}; got {mode!r}"
                    )
                # Deterministic evaluation: always sample the shape named by `mode`.
                candidate_probs = [0.0] * len(candidate_names)
                candidate_probs[candidate_names.index(mode)] = 1.0
            candidate_classes = [getattr(module, name)(cfg, is_train) for name in candidate_names]

        return {
            "max_candidate": max_candidate,
            "shape_prob": candidate_probs,
            "shape_candidate": candidate_classes,
            "is_train": is_train,
        }

    def forward(self, instances):
        """Draw one random shape for each selected instance.

        Args:
            instances: Object exposing ``gt_masks.tensor`` and ``gt_boxes.tensor``
                (presumably a detectron2-style ``Instances``; confirm at caller).

        Returns:
            dict with keys:
              - ``'gt_masks'``: the selected masks. A mask whose drawn shape is
                empty is zeroed in the returned tensor. The caller's tensor is
                never modified.
              - ``'rand_shape'``: stacked boolean tensor of the drawn shapes.
              - ``'types'``: ``repr`` of the generator used per instance, or
                ``'none'`` when the drawn shape was empty.
              - ``'sampler'``: this sampler.
            With no instances, a single all-False mask and shape of the input's
            trailing spatial size are returned, with type ``'none'``.
        """
        masks = instances.gt_masks.tensor
        boxes = instances.gt_boxes.tensor

        if len(masks) == 0:
            gt_masks = torch.zeros(masks.shape[-2:]).bool()
            rand_masks = torch.zeros(masks.shape[-2:]).bool()
            return {
                'gt_masks': gt_masks[None, :],
                'rand_shape': torch.stack([rand_masks]),
                'types': ['none'],
                'sampler': self,
            }

        if self.is_train:
            indices = list(range(len(masks)))
            random.shuffle(indices)
            selected = indices[:self.max_candidate]
            # List indexing (advanced indexing) returns copies, not views.
            candidate_mask = masks[selected]
            candidate_box = boxes[selected]
        else:
            # These alias the caller's tensors; see the copy-on-write below.
            candidate_mask = masks
            candidate_box = boxes

        draw_funcs = random.choices(self.shape_candidate, weights=self.shape_prob, k=len(candidate_mask))
        rand_shapes = [d.draw(x, y) for d, x, y in zip(draw_funcs, candidate_mask, candidate_box)]
        types = [repr(d) for d in draw_funcs]

        empty = [i for i, shape in enumerate(rand_shapes) if shape.sum() == 0]
        if empty:
            if candidate_mask is masks:
                # Copy before zeroing so the input instances are not mutated in place.
                candidate_mask = candidate_mask.clone()
            for i in empty:
                candidate_mask[i] = 0
                types[i] = 'none'

        return {'gt_masks': candidate_mask, 'rand_shape': torch.stack(rand_shapes).bool(), 'types': types, 'sampler': self}
|
|
|
def build_shape_sampler(cfg, **kwargs):
    """Factory returning a ``ShapeSampler`` built from ``cfg``.

    Args:
        cfg: Config passed as the first positional argument to ``ShapeSampler``.
        **kwargs: Forwarded unchanged (e.g. ``is_train``, ``mode``). These are
            presumably routed to ``ShapeSampler.from_config`` by the
            ``@configurable`` decorator; confirm in ``modeling.utils``.

    Returns:
        The constructed ``ShapeSampler``.
    """
    sampler = ShapeSampler(cfg, **kwargs)
    return sampler