# moistdio's picture
# Upload folder using huggingface_hub
# 6831a54 verified
import torch
import math
def repeat_to_batch_size(tensor, batch_size):
    """Return ``tensor`` with its first dimension made exactly ``batch_size`` long.

    * If the first dimension is already ``batch_size``, the input object
      itself is returned.
    * If it is longer, the leading ``batch_size`` entries are kept.
    * If it is shorter, the whole tensor is tiled along the first dimension
      enough times to cover ``batch_size`` and the surplus is cut off.

    All other dimensions are left untouched.
    """
    current = tensor.shape[0]

    if current == batch_size:
        return tensor
    if current > batch_size:
        return tensor[:batch_size]

    # Tile only along dim 0; every remaining dim gets a repeat factor of 1.
    copies = math.ceil(batch_size / current)
    trailing_dims = len(tensor.shape) - 1
    tiled = tensor.repeat([copies] + [1] * trailing_dims)
    return tiled[:batch_size]
def lcm(a, b):
    """Return the least common multiple of the integers ``a`` and ``b``.

    The result is always non-negative. If either argument is zero the result
    is 0 (the same convention as ``math.lcm``); this also avoids the
    ``ZeroDivisionError`` that ``abs(a * b) // gcd(a, b)`` would raise for
    ``lcm(0, 0)``, where the gcd itself is 0.
    """
    if a == 0 or b == 0:
        return 0
    return abs(a * b) // math.gcd(a, b)
class Condition:
    """Wrapper around a single conditioning value stored in ``self.cond``.

    Subclasses override the hooks below to change how the value is prepared
    for a batch and how several conditions are merged together.
    """

    def __init__(self, cond):
        self.cond = cond

    def _copy_with(self, cond):
        """Build a new instance of the same concrete class around ``cond``."""
        return self.__class__(cond)

    def process_cond(self, batch_size, device, **kwargs):
        """Return a copy whose value is fitted to ``batch_size`` and moved to ``device``.

        Extra keyword arguments are accepted and ignored.
        """
        fitted = repeat_to_batch_size(self.cond, batch_size)
        return self._copy_with(fitted.to(device))

    def can_concat(self, other):
        """Tell whether ``other`` holds a value of exactly the same shape."""
        return self.cond.shape == other.cond.shape

    def concat(self, others):
        """Concatenate this value with those of ``others`` via ``torch.cat``.

        This instance's value comes first, followed by ``others`` in order.
        """
        return torch.cat([self.cond, *(item.cond for item in others)])
class ConditionNoiseShape(Condition):
    """Condition whose value is cropped to a rectangular area before batching."""

    def process_cond(self, batch_size, device, area, **kwargs):
        """Crop the value to ``area``, fit it to ``batch_size`` and move it to ``device``.

        ``area`` is indexed as ``(size_dim2, size_dim3, offset_dim2, offset_dim3)``:
        the crop on dimension 2 starts at ``area[2]`` and spans ``area[0]``
        entries, the crop on dimension 3 starts at ``area[3]`` and spans
        ``area[1]`` entries. The first two dimensions are kept whole.
        Extra keyword arguments are accepted and ignored.
        """
        # NOTE(review): dims 2/3 are presumably height/width of a latent - confirm.
        span_dim2 = slice(area[2], area[0] + area[2])
        span_dim3 = slice(area[3], area[1] + area[3])
        cropped = self.cond[:, :, span_dim2, span_dim3]
        fitted = repeat_to_batch_size(cropped, batch_size)
        return self._copy_with(fitted.to(device))
class ConditionCrossAttn(Condition):
    """Condition whose values may differ in length along dimension 1.

    Values of different dim-1 length are merged by repeating the shorter ones
    along that dimension until they all reach a common length.
    """

    # NOTE(review): the code indexes shape[0..2] and repeats with three
    # factors, so values are presumably (batch, tokens, channels) - confirm.

    def can_concat(self, other):
        """Tell whether ``other`` can be merged with this condition.

        Identical shapes are always accepted. Otherwise dimensions 0 and 2
        must match, and bringing the shorter dim-1 length up to the least
        common multiple of both lengths must need at most a 4x repetition.
        """
        mine = self.cond.shape
        theirs = other.cond.shape

        if mine == theirs:
            return True
        if mine[0] != theirs[0] or mine[2] != theirs[2]:
            return False

        common_len = lcm(mine[1], theirs[1])
        growth = common_len // min(mine[1], theirs[1])
        # Hard cap of 4 on the repetition factor - presumably to bound the
        # cost of padding; verify before changing.
        return growth <= 4

    def concat(self, others):
        """Concatenate this value with those of ``others`` along dimension 0.

        The target dim-1 length is the least common multiple of all dim-1
        lengths; any value shorter than that is repeated along dimension 1
        to reach it before ``torch.cat`` is applied.
        """
        tensors = [self.cond, *(item.cond for item in others)]

        target_len = self.cond.shape[1]
        for extra in tensors[1:]:
            target_len = lcm(target_len, extra.shape[1])

        padded = [
            t.repeat(1, target_len // t.shape[1], 1) if t.shape[1] < target_len else t
            for t in tensors
        ]
        return torch.cat(padded)
class ConditionConstant(Condition):
    """Condition holding a value that is passed through unchanged."""

    def __init__(self, cond):
        super().__init__(cond)

    def process_cond(self, batch_size, device, **kwargs):
        """Return a copy wrapping the very same value.

        ``batch_size``, ``device`` and any extra keyword arguments are ignored.
        """
        return self._copy_with(self.cond)

    def can_concat(self, other):
        """Tell whether ``other`` holds a value that does not compare unequal."""
        # Deliberately phrased through ``!=`` so the value's own inequality
        # operator decides the outcome.
        return not (self.cond != other.cond)

    def concat(self, others):
        """Return this condition's value; ``others`` is not consulted."""
        return self.cond
def compile_conditions(cond):
    """Wrap raw conditioning data into a one-element list of condition dicts.

    * ``None`` is returned unchanged.
    * A ``torch.Tensor`` becomes an entry with ``cross_attn`` and a
      ``model_conds`` dict holding ``c_crossattn``.
    * Anything else is read as a mapping with the keys ``'crossattn'`` and
      ``'vector'`` (both required) and ``'guidance'`` (optional). The entry
      then carries ``cross_attn``, ``pooled_output`` and a ``model_conds``
      dict with ``c_crossattn``, ``y`` and, when present, ``guidance``.
    """
    if cond is None:
        return None

    if isinstance(cond, torch.Tensor):
        entry = {
            'cross_attn': cond,
            'model_conds': {
                'c_crossattn': ConditionCrossAttn(cond),
            },
        }
        return [entry]

    cross_attn = cond['crossattn']
    pooled_output = cond['vector']

    model_conds = {
        'c_crossattn': ConditionCrossAttn(cross_attn),
        'y': Condition(pooled_output),
    }
    if 'guidance' in cond:
        model_conds['guidance'] = Condition(cond['guidance'])

    entry = {
        'cross_attn': cross_attn,
        'pooled_output': pooled_output,
        'model_conds': model_conds,
    }
    return [entry]
def compile_weighted_conditions(cond, weights):
    """Compile one condition entry per column of ``weights``.

    ``weights`` is a sequence of sequences of ``(index, weight)`` pairs. The
    sequences are walked in lockstep (stopping at the shortest one); for each
    column the indices are collected and used to select from ``cond`` - via
    ``cond.advanced_indexing(indices)`` when that attribute exists, otherwise
    via ``cond[indices]``. The selection is compiled with
    ``compile_conditions`` and its first entry is tagged with ``'strength'``
    set to the weight of the last pair in the column.

    Returns the flat list of all compiled entries.
    """
    compiled = []

    for column in zip(*weights):
        pairs = [(index, weight) for index, weight in column]
        indices = [index for index, _ in pairs]
        # The last pair's weight wins; 0 is the fallback for an empty column.
        strength = pairs[-1][1] if pairs else 0

        if hasattr(cond, 'advanced_indexing'):
            selected = cond.advanced_indexing(indices)
        else:
            selected = cond[indices]

        entries = compile_conditions(selected)
        entries[0]['strength'] = strength
        compiled.extend(entries)

    return compiled