scdrand23's picture
not working version
814a594
import sys
import random
import torch
import torch.nn as nn
from .point import Point
from .polygon import Polygon
from .scribble import Scribble
from .circle import Circle
from modeling.utils import configurable
class ShapeSampler(nn.Module):
@configurable
def __init__(self, max_candidate=1, shape_prob=[], shape_candidate=[], is_train=True):
super().__init__()
self.max_candidate = max_candidate
self.shape_prob = shape_prob
self.shape_candidate = shape_candidate
self.is_train = is_train
@classmethod
def from_config(cls, cfg, is_train=True, mode=None):
max_candidate = cfg['STROKE_SAMPLER']['MAX_CANDIDATE']
candidate_probs = cfg['STROKE_SAMPLER']['CANDIDATE_PROBS']
candidate_names = cfg['STROKE_SAMPLER']['CANDIDATE_NAMES']
if mode == 'hack_train':
candidate_classes = [getattr(sys.modules[__name__], class_name)(cfg, True) for class_name in candidate_names]
else:
# overwrite condidate_prob
if not is_train:
candidate_probs = [0.0 for x in range(len(candidate_names))]
candidate_probs[candidate_names.index(mode)] = 1.0
candidate_classes = [getattr(sys.modules[__name__], class_name)(cfg, is_train) for class_name in candidate_names]
# Build augmentation
return {
"max_candidate": max_candidate,
"shape_prob": candidate_probs,
"shape_candidate": candidate_classes,
"is_train": is_train,
}
def forward(self, instances):
masks = instances.gt_masks.tensor
boxes = instances.gt_boxes.tensor
if len(masks) == 0:
gt_masks = torch.zeros(masks.shape[-2:]).bool()
rand_masks = torch.zeros(masks.shape[-2:]).bool()
return {'gt_masks': gt_masks[None,:], 'rand_shape': torch.stack([rand_masks]), 'types': ['none']}
indices = [x for x in range(len(masks))]
if self.is_train:
random.shuffle(indices)
candidate_mask = masks[indices[:self.max_candidate]]
candidate_box = boxes[indices[:self.max_candidate]]
else:
candidate_mask = masks
candidate_box = boxes
draw_funcs = random.choices(self.shape_candidate, weights=self.shape_prob, k=len(candidate_mask))
rand_shapes = [d.draw(x,y) for d,x,y in zip(draw_funcs, candidate_mask, candidate_box)]
types = [repr(x) for x in draw_funcs]
for i in range(0, len(rand_shapes)):
if rand_shapes[i].sum() == 0:
candidate_mask[i] = candidate_mask[i] * 0
types[i] = 'none'
# candidate_mask: (c,h,w), bool. rand_shape: (c, iter, h, w), bool. types: list(c)
return {'gt_masks': candidate_mask, 'rand_shape': torch.stack(rand_shapes).bool(), 'types': types, 'sampler': self}
def build_shape_sampler(cfg, **kwargs):
return ShapeSampler(cfg, **kwargs)