Spaces:
Runtime error
Runtime error
import torch | |
import safetensors.torch as sf | |
from backend import utils | |
class ForgeObjects: | |
def __init__(self, unet, clip, vae, clipvision): | |
self.unet = unet | |
self.clip = clip | |
self.vae = vae | |
self.clipvision = clipvision | |
def shallow_copy(self): | |
return ForgeObjects( | |
self.unet, | |
self.clip, | |
self.vae, | |
self.clipvision | |
) | |
class ForgeDiffusionEngine: | |
matched_guesses = [] | |
def __init__(self, estimated_config, huggingface_components): | |
self.model_config = estimated_config | |
self.is_inpaint = estimated_config.inpaint_model() | |
self.forge_objects = None | |
self.forge_objects_original = None | |
self.forge_objects_after_applying_lora = None | |
self.current_lora_hash = str([]) | |
self.fix_for_webui_backward_compatibility() | |
def set_clip_skip(self, clip_skip): | |
pass | |
def get_first_stage_encoding(self, x): | |
return x # legacy code, do not change | |
def get_learned_conditioning(self, prompt: list[str]): | |
pass | |
def encode_first_stage(self, x): | |
pass | |
def decode_first_stage(self, x): | |
pass | |
def get_prompt_lengths_on_ui(self, prompt): | |
return 0, 75 | |
def is_webui_legacy_model(self): | |
return self.is_sd1 or self.is_sd2 or self.is_sdxl or self.is_sd3 | |
def fix_for_webui_backward_compatibility(self): | |
self.tiling_enabled = False | |
self.first_stage_model = None | |
self.cond_stage_model = None | |
self.use_distilled_cfg_scale = False | |
self.is_sd1 = False | |
self.is_sd2 = False | |
self.is_sdxl = False | |
self.is_sd3 = False | |
return | |
def save_unet(self, filename): | |
sd = utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model) | |
sf.save_file(sd, filename) | |
return filename | |
def save_checkpoint(self, filename): | |
sd = {} | |
sd.update( | |
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.') | |
) | |
sd.update( | |
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='text_encoders.') | |
) | |
sd.update( | |
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='vae.') | |
) | |
sf.save_file(sd, filename) | |
return filename | |