File size: 5,673 Bytes
bfd34e9
f1cc496
bfd34e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch 
from src.utils.iimage import IImage

class InputMask:
    def to(self, device): return InputMask(self.image, device = device)
    def cuda(self): return InputMask(self.image, device = 'cuda')
    def cpu(self): return InputMask(self.image, device = 'cpu')

    def __init__(self, input_image, device = 'cpu'):
        '''
        args:
            input_image: (b,c,h,w) tensor
        '''
        if hasattr(input_image, 'is_iimage'):
            self.image = input_image
            self.val512 = self.full = (input_image.torch(0) > 0.5).float()
        elif isinstance(input_image, torch.Tensor):
            self.val512 = self.full = input_image.clone()
            self.image = IImage(input_image,0)
        
        self.h,self.w = h,w = self.val512.shape[-2:]
        self.shape   = [self.h, self.w]
        self.shape64 = [self.h // 8, self.w // 8]
        self.shape32 = [self.h // 16, self.w // 16]
        self.shape16 = [self.h // 32, self.w // 32]
        self.shape8  = [self.h // 64, self.w // 64]

        self.res   = self.h * self.w
        self.res64 = self.res // 64
        self.res32 = self.res // 64 // 4
        self.res16 = self.res // 64 // 16
        self.res8  = self.res // 64 // 64

        self.img = self.image
        self.img512 = self.image
        self.img64 = self.image.resize((h//8,w//8))
        self.img32 = self.image.resize((h//16,w//16))
        self.img16 = self.image.resize((h//32,w//32))
        self.img8  = self.image.resize((h//64,w//64))

        self.val64 = (self.img64.torch(0) > 0.5).float()
        self.val32 = (self.img32.torch(0) > 0.5).float()
        self.val16 = (self.img16.torch(0) > 0.5).float()
        self.val8  = ( self.img8.torch(0) > 0.5).float()

    
    def get_res(self, q, device = 'cpu'):
        if q.shape[1] == self.res64: return self.val64.to(device)
        if q.shape[1] == self.res32: return self.val32.to(device)
        if q.shape[1] == self.res16: return self.val16.to(device)
        if q.shape[1] == self.res8: return self.val8.to(device)

    def get_res(self, q, device = 'cpu'):
        if q.shape[1] == self.res64: return self.val64.to(device)
        if q.shape[1] == self.res32: return self.val32.to(device)
        if q.shape[1] == self.res16: return self.val16.to(device)
        if q.shape[1] == self.res8: return self.val8.to(device)

    def get_shape(self, q, device = 'cpu'):
        if q.shape[1] == self.res64: return self.shape64
        if q.shape[1] == self.res32: return self.shape32
        if q.shape[1] == self.res16: return self.shape16
        if q.shape[1] == self.res8: return self.shape8

    def get_res_val(self, q, device = 'cpu'):
        if q.shape[1] == self.res64: return 64
        if q.shape[1] == self.res32: return 32
        if q.shape[1] == self.res16: return 16
        if q.shape[1] == self.res8: return 8


class InputMask2:
    def to(self, device): return InputMask2(self.image, device = device)
    def cuda(self): return InputMask2(self.image, device = 'cuda')
    def cpu(self): return InputMask2(self.image, device = 'cpu')

    def __init__(self, input_image, device = 'cpu'):
        '''
        args:
            input_image: (b,c,h,w) tensor
        '''
        if hasattr(input_image, 'is_iimage'):
            self.image = input_image
            self.val512 = self.full = input_image.torch(0).bool().float()
        elif isinstance(input_image, torch.Tensor):
            self.val512 = self.full = input_image.clone()
            self.image = IImage(input_image,0)
        
        self.h,self.w = h,w = self.val512.shape[-2:]
        self.shape   = [self.h, self.w]
        self.shape64 = [self.h // 8, self.w // 8]
        self.shape32 = [self.h // 16, self.w // 16]
        self.shape16 = [self.h // 32, self.w // 32]
        self.shape8  = [self.h // 64, self.w // 64]

        self.res   = self.h * self.w
        self.res64 = self.res // 64
        self.res32 = self.res // 64 // 4
        self.res16 = self.res // 64 // 16
        self.res8  = self.res // 64 // 64

        self.img = self.image
        self.img512 = self.image
        self.img64 = self.image.resize((h//8,w//8))
        self.img32 = self.image.resize((h//16,w//16))
        self.img16 = self.image.resize((h//32,w//32)).dilate(1)
        self.img8  = self.image.resize((h//64,w//64)).dilate(1)

        self.val64 = self.img64.torch(0).bool().float()
        self.val32 = self.img32.torch(0).bool().float()
        self.val16 = self.img16.torch(0).bool().float()
        self.val8  =  self.img8.torch(0).bool().float()

    
    def get_res(self, q, device = 'cpu'):
        if q.shape[1] == self.res64: return self.val64.to(device)
        if q.shape[1] == self.res32: return self.val32.to(device)
        if q.shape[1] == self.res16: return self.val16.to(device)
        if q.shape[1] == self.res8: return self.val8.to(device)

    def get_res(self, q, device = 'cpu'):
        if q.shape[1] == self.res64: return self.val64.to(device)
        if q.shape[1] == self.res32: return self.val32.to(device)
        if q.shape[1] == self.res16: return self.val16.to(device)
        if q.shape[1] == self.res8: return self.val8.to(device)

    def get_shape(self, q, device = 'cpu'):
        if q.shape[1] == self.res64: return self.shape64
        if q.shape[1] == self.res32: return self.shape32
        if q.shape[1] == self.res16: return self.shape16
        if q.shape[1] == self.res8: return self.shape8

    def get_res_val(self, q, device = 'cpu'):
        if q.shape[1] == self.res64: return 64
        if q.shape[1] == self.res32: return 32
        if q.shape[1] == self.res16: return 16
        if q.shape[1] == self.res8: return 8