Spaces:
Paused
Paused
File size: 5,673 Bytes
bfd34e9 f1cc496 bfd34e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import torch
from src.utils.iimage import IImage
class InputMask:
    """Mask image pre-computed at several reduced resolutions.

    The mask is kept at its original size (``val512`` / ``full``) and at
    1/8, 1/16, 1/32 and 1/64 of that size per side.  The numeric suffixes
    (64, 32, 16, 8) are the side lengths obtained for a 512-pixel input
    (512 // 8 == 64, and so on).  Every reduced mask is binarised with a
    ``> 0.5`` threshold.

    For each level the instance exposes:
        shapeNN -- ``[height, width]`` of the reduced mask
        resNN   -- number of elements per reduced mask (derived from h * w)
        imgNN   -- the resized image object returned by ``image.resize``
        valNN   -- the thresholded float tensor of ``imgNN``
    """

    # Levels in the order they are matched by the lookup helpers.
    _LEVELS = (64, 32, 16, 8)

    def __init__(self, input_image, device='cpu'):
        """Build the multi-resolution mask.

        args:
            input_image: either an IImage-like object (anything exposing an
                ``is_iimage`` attribute) or a ``torch.Tensor``; the original
                docstring describes the tensor as (b,c,h,w).
            device: accepted for interface compatibility.
                NOTE(review): this argument is not applied to any tensor
                here, so ``to`` / ``cuda`` / ``cpu`` do not move data.
                Use ``get_res(q, device)`` to obtain a mask on a device.

        raises:
            TypeError: if ``input_image`` is neither IImage-like nor a tensor.
        """
        if hasattr(input_image, 'is_iimage'):
            self.image = input_image
            self.val512 = self.full = (input_image.torch(0) > 0.5).float()
        elif isinstance(input_image, torch.Tensor):
            # NOTE(review): the tensor is stored as given (not thresholded),
            # unlike the IImage branch -- confirm this is intended.
            self.val512 = self.full = input_image.clone()
            self.image = IImage(input_image, 0)
        else:
            # Previously this fell through and failed below with an
            # AttributeError on ``self.val512``.
            raise TypeError(
                'input_image must be an IImage or a torch.Tensor, got '
                + type(input_image).__name__
            )

        self.h, self.w = h, w = self.val512.shape[-2:]

        self.shape = [self.h, self.w]
        self.shape64 = [self.h // 8, self.w // 8]
        self.shape32 = [self.h // 16, self.w // 16]
        self.shape16 = [self.h // 32, self.w // 32]
        self.shape8 = [self.h // 64, self.w // 64]

        # Element counts are derived from the full resolution, not from the
        # floored shapes above.
        self.res = self.h * self.w
        self.res64 = self.res // 64
        self.res32 = self.res // 64 // 4
        self.res16 = self.res // 64 // 16
        self.res8 = self.res // 64 // 64

        self.img = self.image
        self.img512 = self.image
        self.img64 = self.image.resize((h // 8, w // 8))
        self.img32 = self.image.resize((h // 16, w // 16))
        self.img16 = self.image.resize((h // 32, w // 32))
        self.img8 = self.image.resize((h // 64, w // 64))

        self.val64 = (self.img64.torch(0) > 0.5).float()
        self.val32 = (self.img32.torch(0) > 0.5).float()
        self.val16 = (self.img16.torch(0) > 0.5).float()
        self.val8 = (self.img8.torch(0) > 0.5).float()

    def to(self, device):
        """Return a new InputMask built from the same image (see ``device`` note in ``__init__``)."""
        return InputMask(self.image, device=device)

    def cuda(self):
        """Return a new InputMask built from the same image, requesting 'cuda'."""
        return InputMask(self.image, device='cuda')

    def cpu(self):
        """Return a new InputMask built from the same image, requesting 'cpu'."""
        return InputMask(self.image, device='cpu')

    def _match_level(self, q):
        """Return the level (64, 32, 16 or 8) whose ``resNN`` equals ``q.shape[1]``.

        Levels are tried from finest to coarsest and the first match wins.
        Returns None when nothing matches.
        """
        length = q.shape[1]
        for level in self._LEVELS:
            if length == getattr(self, 'res{}'.format(level)):
                return level
        return None

    def get_res(self, q, device='cpu'):
        """Return the mask tensor matching ``q.shape[1]``, moved to ``device``.

        ``q.shape[1]`` is presumably the flattened spatial length of an
        attention tensor -- confirm against callers.  Returns None when no
        level matches.
        """
        level = self._match_level(q)
        if level is None:
            return None
        return getattr(self, 'val{}'.format(level)).to(device)

    def get_shape(self, q, device='cpu'):
        """Return ``[height, width]`` of the level matching ``q.shape[1]``, or None.

        ``device`` is unused; it is kept so the signature mirrors ``get_res``.
        """
        level = self._match_level(q)
        if level is None:
            return None
        return getattr(self, 'shape{}'.format(level))

    def get_res_val(self, q, device='cpu'):
        """Return the level number (64, 32, 16 or 8) matching ``q.shape[1]``, or None.

        ``device`` is unused; it is kept so the signature mirrors ``get_res``.
        """
        return self._match_level(q)
class InputMask2:
    """Mask image pre-computed at several reduced resolutions (non-zero variant).

    The mask is kept at its original size (``val512`` / ``full``) and at
    1/8, 1/16, 1/32 and 1/64 of that size per side.  The numeric suffixes
    (64, 32, 16, 8) are the side lengths obtained for a 512-pixel input
    (512 // 8 == 64, and so on).  Masks are binarised with
    ``.bool().float()``, i.e. every non-zero value becomes 1.0, and the two
    coarsest images (1/32 and 1/64) are passed through ``dilate(1)`` after
    resizing.

    For each level the instance exposes:
        shapeNN -- ``[height, width]`` of the reduced mask
        resNN   -- number of elements per reduced mask (derived from h * w)
        imgNN   -- the resized (and, for 16 / 8, dilated) image object
        valNN   -- the binarised float tensor of ``imgNN``
    """

    # Levels in the order they are matched by the lookup helpers.
    _LEVELS = (64, 32, 16, 8)

    def __init__(self, input_image, device='cpu'):
        """Build the multi-resolution mask.

        args:
            input_image: either an IImage-like object (anything exposing an
                ``is_iimage`` attribute) or a ``torch.Tensor``; the original
                docstring describes the tensor as (b,c,h,w).
            device: accepted for interface compatibility.
                NOTE(review): this argument is not applied to any tensor
                here, so ``to`` / ``cuda`` / ``cpu`` do not move data.
                Use ``get_res(q, device)`` to obtain a mask on a device.

        raises:
            TypeError: if ``input_image`` is neither IImage-like nor a tensor.
        """
        if hasattr(input_image, 'is_iimage'):
            self.image = input_image
            self.val512 = self.full = input_image.torch(0).bool().float()
        elif isinstance(input_image, torch.Tensor):
            # NOTE(review): the tensor is stored as given (not binarised),
            # unlike the IImage branch -- confirm this is intended.
            self.val512 = self.full = input_image.clone()
            self.image = IImage(input_image, 0)
        else:
            # Previously this fell through and failed below with an
            # AttributeError on ``self.val512``.
            raise TypeError(
                'input_image must be an IImage or a torch.Tensor, got '
                + type(input_image).__name__
            )

        self.h, self.w = h, w = self.val512.shape[-2:]

        self.shape = [self.h, self.w]
        self.shape64 = [self.h // 8, self.w // 8]
        self.shape32 = [self.h // 16, self.w // 16]
        self.shape16 = [self.h // 32, self.w // 32]
        self.shape8 = [self.h // 64, self.w // 64]

        # Element counts are derived from the full resolution, not from the
        # floored shapes above.
        self.res = self.h * self.w
        self.res64 = self.res // 64
        self.res32 = self.res // 64 // 4
        self.res16 = self.res // 64 // 16
        self.res8 = self.res // 64 // 64

        self.img = self.image
        self.img512 = self.image
        self.img64 = self.image.resize((h // 8, w // 8))
        self.img32 = self.image.resize((h // 16, w // 16))
        # Only the two coarsest levels are dilated after resizing.
        self.img16 = self.image.resize((h // 32, w // 32)).dilate(1)
        self.img8 = self.image.resize((h // 64, w // 64)).dilate(1)

        self.val64 = self.img64.torch(0).bool().float()
        self.val32 = self.img32.torch(0).bool().float()
        self.val16 = self.img16.torch(0).bool().float()
        self.val8 = self.img8.torch(0).bool().float()

    def to(self, device):
        """Return a new InputMask2 built from the same image (see ``device`` note in ``__init__``)."""
        return InputMask2(self.image, device=device)

    def cuda(self):
        """Return a new InputMask2 built from the same image, requesting 'cuda'."""
        return InputMask2(self.image, device='cuda')

    def cpu(self):
        """Return a new InputMask2 built from the same image, requesting 'cpu'."""
        return InputMask2(self.image, device='cpu')

    def _match_level(self, q):
        """Return the level (64, 32, 16 or 8) whose ``resNN`` equals ``q.shape[1]``.

        Levels are tried from finest to coarsest and the first match wins.
        Returns None when nothing matches.
        """
        length = q.shape[1]
        for level in self._LEVELS:
            if length == getattr(self, 'res{}'.format(level)):
                return level
        return None

    def get_res(self, q, device='cpu'):
        """Return the mask tensor matching ``q.shape[1]``, moved to ``device``.

        ``q.shape[1]`` is presumably the flattened spatial length of an
        attention tensor -- confirm against callers.  Returns None when no
        level matches.
        """
        level = self._match_level(q)
        if level is None:
            return None
        return getattr(self, 'val{}'.format(level)).to(device)

    def get_shape(self, q, device='cpu'):
        """Return ``[height, width]`` of the level matching ``q.shape[1]``, or None.

        ``device`` is unused; it is kept so the signature mirrors ``get_res``.
        """
        level = self._match_level(q)
        if level is None:
            return None
        return getattr(self, 'shape{}'.format(level))

    def get_res_val(self, q, device='cpu'):
        """Return the level number (64, 32, 16 or 8) matching ``q.shape[1]``, or None.

        ``device`` is unused; it is kept so the signature mirrors ``get_res``.
        """
        return self._match_level(q)