|
import bisect |
|
import torch |
|
import torch.nn.functional as F |
|
import lpips |
|
|
|
# Shared LPIPS model, constructed once when the module is imported so that
# every ``distance`` call reuses the same instance.
perceptual_loss = lpips.LPIPS()


def distance(img_a, img_b):
    """Return the LPIPS distance between ``img_a`` and ``img_b`` as a float.

    Both arguments are passed straight to the module-level ``perceptual_loss``
    model; presumably they are image tensors in the layout LPIPS expects --
    TODO confirm against callers. The model output is converted with
    ``.item()``, so it must contain exactly one element (one image pair per
    call).
    """
    # Only the scalar value is consumed, so there is no reason to record an
    # autograd graph for the forward pass.
    with torch.no_grad():
        return perceptual_loss(img_a, img_b).item()
|
|
|
|
|
|
|
class AlphaScheduler:
    """Invert a cumulative-distance curve built from a sequence of images.

    The schedule is a non-decreasing list ``values`` where ``values[i]`` is
    the normalised cumulative ``distance`` from the first image up to image
    ``i`` (``values[0] == 0`` and ``values[-1] == 1`` after ``from_imgs``).
    Image ``i`` is placed at position ``i / (n - 1)``; ``get_x`` maps a target
    cumulative fraction ``y`` back to a position ``x`` by linear interpolation
    between neighbouring images.
    """

    def __init__(self):
        # Start with an empty schedule; ``from_imgs`` or ``load`` fills it in.
        self.__values = []
        self.__num_values = 0

    def from_imgs(self, imgs):
        """Build the schedule from an ordered, indexable sequence of images.

        Args:
            imgs: Sized sequence supporting ``imgs[i]``; consecutive items are
                compared with the module-level ``distance`` function.

        Raises:
            ValueError: If the total distance is not positive (fewer than two
                images, or every consecutive pair has zero distance), because
                the curve cannot be normalised.
        """
        count = len(imgs)
        cumulative = [0.0]
        for i in range(count - 1):
            cumulative.append(cumulative[-1] + distance(imgs[i], imgs[i + 1]))

        total = cumulative[-1]
        # ``not total > 0`` also rejects NaN, which a plain ``<=`` would miss.
        if not total > 0:
            raise ValueError(
                "cannot build a schedule: total distance between consecutive "
                f"images is {total!r} (need at least two distinct images)"
            )

        # Normalise into a fresh list so the divisor is never mutated midway.
        self.__values = [value / total for value in cumulative]
        self.__num_values = count

    def save(self, filename):
        """Store the schedule in ``filename`` as a 1-D tensor via ``torch.save``."""
        torch.save(torch.tensor(self.__values), filename)

    def load(self, filename):
        """Replace the schedule with the tensor stored in ``filename``."""
        # NOTE(review): ``torch.load`` unpickles the file; only load schedule
        # files from trusted sources.
        self.__values = torch.load(filename).tolist()
        self.__num_values = len(self.__values)

    def get_x(self, y):
        """Return the position ``x`` at which the cumulative curve reaches ``y``.

        Args:
            y: Target cumulative fraction in ``[0, 1]`` (a number or a
                zero-dimensional tensor).

        Returns:
            The interpolated position, of the same numeric kind as ``y``
            combined with Python floats.

        Raises:
            ValueError: If the schedule holds fewer than two values, or ``y``
                lies outside ``[0, 1]``.
        """
        if self.__num_values < 2:
            raise ValueError(
                "schedule needs at least two values; call from_imgs() or "
                "load() first"
            )
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not 0 <= y <= 1:
            raise ValueError(f"y must be within [0, 1], got {y!r}")

        values = self.__values
        # Segment whose left end is the last value strictly below ``y``.
        # Clamp to a valid segment: ``y == 0`` falls before the first value,
        # and a stored last value below ``y`` would fall after the final one.
        segment = bisect.bisect_left(values, y) - 1
        segment = min(max(segment, 0), self.__num_values - 2)

        y_lo = values[segment]
        y_hi = values[segment + 1]
        step = 1 / (self.__num_values - 1)
        x_lo = segment * step
        x_hi = (segment + 1) * step

        span = y_hi - y_lo
        if span > 0:
            fraction = (y - y_lo) / span
        else:
            # Flat segment (neighbouring images at zero distance): there is no
            # unique inverse, so snap to the left edge rather than divide by
            # zero. ``y * 0`` keeps the result the same numeric type as ``y``.
            fraction = y * 0
        return fraction * (x_hi - x_lo) + x_lo

    def get_list(self, len=None):
        """Return ``get_x`` evaluated at ``len`` evenly spaced targets in [0, 1].

        Args:
            len: Number of samples; defaults to the number of stored values.
                (The name shadows the builtin but is kept for keyword callers.)

        Returns:
            A list with one entry per sample. The targets come from iterating
            ``torch.linspace``, so each entry is a tensor scalar rather than a
            plain Python float.
        """
        if len is None:
            len = self.__num_values

        return [self.get_x(y) for y in torch.linspace(0, 1, len)]
|
|