import torch

from backend import memory_management, attention
from backend.modules.k_prediction import k_prediction_from_diffusers_scheduler


class KModel(torch.nn.Module):
    def __init__(self, model, diffusers_scheduler, k_predictor=None, config=None):
        super().__init__()

        self.config = config
        self.storage_dtype = model.storage_dtype
        self.computation_dtype = model.computation_dtype

        print(f'K-Model Created: {dict(storage_dtype=self.storage_dtype, computation_dtype=self.computation_dtype)}')

        self.diffusion_model = model

        # Fall back to a predictor derived from the diffusers scheduler when no
        # explicit k-diffusion predictor is supplied.
        if k_predictor is None:
            self.predictor = k_prediction_from_diffusers_scheduler(diffusers_scheduler)
        else:
            self.predictor = k_predictor

    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options=None, **kwargs):
        # Avoid the mutable-default-argument pitfall for transformer_options.
        if transformer_options is None:
            transformer_options = {}

        sigma = t

        # Scale the noisy latent into the range the network expects at this sigma.
        xc = self.predictor.calculate_input(sigma, x)

        # Channel-wise concat conditioning (e.g. inpainting mask and masked latent).
        if c_concat is not None:
            xc = torch.cat([xc, c_concat], dim=1)

        context = c_crossattn
        dtype = self.computation_dtype

        xc = xc.to(dtype)
        t = self.predictor.timestep(t).float()
        context = context.to(dtype)

        # Cast extra conditioning tensors to the computation dtype, leaving
        # integer tensors (e.g. token ids) untouched.
        extra_conds = {}
        for name, extra in kwargs.items():
            if hasattr(extra, "dtype") and extra.dtype not in (torch.int, torch.long):
                extra = extra.to(dtype)
            extra_conds[name] = extra

        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()

        # Map the raw network output back to a denoised sample for the sampler.
        return self.predictor.calculate_denoised(sigma, model_output, x)

    def memory_required(self, input_shape):
        # Rough peak-memory estimate for attention: batch * height * width of the
        # latent, scaled by dtype size and an empirical per-backend constant.
        area = input_shape[0] * input_shape[2] * input_shape[3]
        dtype_size = memory_management.dtype_size(self.computation_dtype)

        # Memory-efficient backends (PyTorch SDP, xformers) need less headroom.
        if attention.attention_function in [attention.attention_pytorch, attention.attention_xformers]:
            scaler = 1.28
        else:
            scaler = 1.65

        # Upcast attention accumulates in fp32 regardless of the compute dtype.
        if attention.get_attn_precision() == torch.float32:
            dtype_size = 4

        return scaler * area * dtype_size * 16384
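
# Worked example for memory_required() (illustrative numbers, assuming an
# SD-style 4-channel latent at 1/8 image resolution): a 1024x1024 image gives
# input_shape == (1, 4, 128, 128), so area = 1 * 128 * 128 = 16384. With fp16
# compute (dtype_size = 2) and the PyTorch/xformers path (scaler = 1.28):
#   1.28 * 16384 * 2 * 16384 ≈ 6.87e8 bytes ≈ 655 MiB.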
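
# A minimal, self-contained smoke test of the KModel plumbing. The stub network
# and stub predictor below are illustrative assumptions, not real Forge models
# or a real k-prediction object; they only mimic the attributes and methods
# that KModel actually touches.
if __name__ == "__main__":
    class _StubPredictor:
        # EDM-style input scaling; the exact formula here is a placeholder.
        def calculate_input(self, sigma, x):
            return x / (sigma ** 2.0 + 1.0) ** 0.5

        def timestep(self, sigma):
            return sigma * 100.0

        def calculate_denoised(self, sigma, model_output, x):
            return x - model_output * sigma

    class _StubUNet(torch.nn.Module):
        storage_dtype = torch.float16
        computation_dtype = torch.float32

        def forward(self, x, t, context=None, control=None, transformer_options=None, **kwargs):
            return torch.zeros_like(x)

    km = KModel(_StubUNet(), diffusers_scheduler=None, k_predictor=_StubPredictor())
    out = km.apply_model(
        torch.randn(1, 4, 64, 64),            # noisy latents
        torch.tensor(7.0),                    # sigma for this step
        c_crossattn=torch.randn(1, 77, 768),  # text-encoder embeddings
    )
    print(out.shape)  # torch.Size([1, 4, 64, 64])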