Spaces:
Paused
Paused
File size: 1,265 Bytes
bfd34e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
class InputShape:
    """Token counts and 2-D grid shapes of an image at four downsampled levels.

    The levels are labelled 64, 32, 16 and 8 and correspond to dividing each
    image side by 8, 16, 32 and 64 respectively.
    NOTE(review): the labels presumably name the feature-map side length for
    a 512-pixel input -- confirm against the caller.

    Attributes:
        h, w: Image height and width.
        res: Total pixel count, ``h * w``.
        res64, res32, res16, res8: Flattened token count expected at each
            level, floor-divided from ``res``.
        shape: ``[h, w]``.
        shape64, shape32, shape16, shape8: ``[rows, cols]`` grid at each
            level, floor-divided per side.

    NOTE: ``resN`` is derived from the total pixel count while ``shapeN`` is
    derived from each side separately. The two agree when both sides are
    multiples of 64; for other sizes the product of ``shapeN`` can differ
    from ``resN`` and ``reshape`` will fail inside the array's own
    ``reshape`` call.
    """

    # (label, token-count attribute, grid-shape attribute), in match order.
    # Order matters: for tiny images several counts can coincide (e.g. all
    # be 0) and the first matching level wins.
    _LEVELS = (
        (64, "res64", "shape64"),
        (32, "res32", "shape32"),
        (16, "res16", "shape16"),
        (8, "res8", "shape8"),
    )

    def __init__(self, image_size):
        """Precompute per-level token counts and grid shapes.

        Args:
            image_size: Two-element sequence; it is reversed before
                unpacking, so it is read as ``(w, h)``.
        """
        self.h, self.w = image_size[::-1]
        self.res = self.h * self.w

        self.res64 = self.res // 64
        self.res32 = self.res // 64 // 4
        self.res16 = self.res // 64 // 16
        self.res8 = self.res // 64 // 64

        self.shape = [self.h, self.w]
        self.shape64 = [self.h // 8, self.w // 8]
        self.shape32 = [self.h // 16, self.w // 16]
        self.shape16 = [self.h // 32, self.w // 32]
        self.shape8 = [self.h // 64, self.w // 64]

    def _find_level(self, n_tokens):
        """Return the first ``_LEVELS`` entry whose count equals *n_tokens*, else None."""
        for level in self._LEVELS:
            # Attributes are read at call time so later reassignment of
            # e.g. ``self.res64`` is honoured.
            if n_tokens == getattr(self, level[1]):
                return level
        return None

    def reshape(self, x):
        """Unflatten the middle axis of a 3-D array into its 2-D grid.

        Args:
            x: Array-like exposing ``.shape`` and ``.reshape``; ``x.shape[1]``
                must equal one of ``res64``/``res32``/``res16``/``res8``.

        Returns:
            ``x`` reshaped to ``[x.shape[0], rows, cols, x.shape[-1]]`` where
            ``[rows, cols]`` is the grid shape of the matching level.

        Raises:
            AssertionError: If ``x`` is not 3-dimensional.
            ValueError: If ``x.shape[1]`` matches no known level.
        """
        # Raised explicitly instead of via ``assert`` so the check survives
        # ``python -O``; the exception type is kept for existing callers.
        if len(x.shape) != 3:
            raise AssertionError(
                "expected a 3-D input, got %d dimension(s)" % len(x.shape)
            )

        level = self._find_level(x.shape[1])
        if level is None:
            raise ValueError("Unknown shape")

        grid = getattr(self, level[2])
        return x.reshape([x.shape[0]] + grid + [x.shape[-1]])

    def get_res(self, q, device='cpu'):
        """Return the level label (64, 32, 16 or 8) matching ``q.shape[1]``.

        Args:
            q: Object exposing ``.shape`` with at least two axes.
            device: Unused; kept so existing call sites keep working.

        Returns:
            The matching level label, or ``None`` when no level matches.
        """
        level = self._find_level(q.shape[1])
        if level is None:
            return None
        return level[0]