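"""KModel: k-diffusion wrapper around a diffusion model.

It scales model inputs/outputs through a k-prediction object (built from a
diffusers scheduler by default) and provides a rough memory estimate for
sampling.
"""
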
import torch
from backend import memory_management, attention
from backend.modules.k_prediction import k_prediction_from_diffusers_scheduler


class KModel(torch.nn.Module):
    def __init__(self, model, diffusers_scheduler, k_predictor=None, config=None):
        super().__init__()
        self.config = config
        self.storage_dtype = model.storage_dtype
        self.computation_dtype = model.computation_dtype

        print(f'K-Model Created: {dict(storage_dtype=self.storage_dtype, computation_dtype=self.computation_dtype)}')

        self.diffusion_model = model

        # Fall back to a predictor derived from the diffusers scheduler unless
        # an explicit k-diffusion predictor is supplied.
        if k_predictor is None:
            self.predictor = k_prediction_from_diffusers_scheduler(diffusers_scheduler)
        else:
            self.predictor = k_predictor

    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
        sigma = t
        xc = self.predictor.calculate_input(sigma, x)

        # Channel-wise conditioning (e.g. inpainting masks) is concatenated onto the latent input.
        if c_concat is not None:
            xc = torch.cat([xc] + [c_concat], dim=1)

        context = c_crossattn
        dtype = self.computation_dtype

        xc = xc.to(dtype)
        t = self.predictor.timestep(t).float()
        context = context.to(dtype)

        # Cast any extra tensor conditions to the computation dtype, but leave
        # integer/index tensors untouched.
        extra_conds = {}
        for o in kwargs:
            extra = kwargs[o]
            if hasattr(extra, "dtype"):
                if extra.dtype != torch.int and extra.dtype != torch.long:
                    extra = extra.to(dtype)
            extra_conds[o] = extra

        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
        return self.predictor.calculate_denoised(sigma, model_output, x)

    def memory_required(self, input_shape):
        # Rough estimate: batch * height * width, scaled by the computation
        # dtype size and a per-attention-backend factor.
        area = input_shape[0] * input_shape[2] * input_shape[3]
        dtype_size = memory_management.dtype_size(self.computation_dtype)

        if attention.attention_function in [attention.attention_pytorch, attention.attention_xformers]:
            scaler = 1.28
        else:
            scaler = 1.65

        if attention.get_attn_precision() == torch.float32:
            dtype_size = 4

        return scaler * area * dtype_size * 16384
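

if __name__ == '__main__':
    # Minimal smoke-test sketch only, not part of the module. The dummy classes
    # below are stand-ins invented for illustration; real callers pass a
    # UNet-style module exposing storage_dtype/computation_dtype and usually
    # let the predictor be built from a diffusers scheduler.
    class _DummyUNet(torch.nn.Module):
        storage_dtype = torch.float32
        computation_dtype = torch.float32

        def forward(self, x, t, context=None, control=None, transformer_options=None, **kwargs):
            # Pretend to predict noise; output shape must match the latent input.
            return torch.zeros_like(x)

    class _DummyEpsPredictor:
        # Assumed predictor interface, mirroring how apply_model uses it
        # (eps-style prediction; not the repo's real implementation).
        def calculate_input(self, sigma, x):
            return x

        def timestep(self, sigma):
            return sigma

        def calculate_denoised(self, sigma, model_output, x):
            return x - model_output * sigma.view(-1, 1, 1, 1)

    km = KModel(_DummyUNet(), diffusers_scheduler=None, k_predictor=_DummyEpsPredictor())
    latent = torch.randn(1, 4, 64, 64)
    sigma = torch.ones(1)
    cond = torch.randn(1, 77, 768)
    denoised = km.apply_model(latent, sigma, c_crossattn=cond)
    print(denoised.shape, km.memory_required(latent.shape))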