Spaces:
Runtime error
Runtime error
import torch.nn | |
from modules import script_callbacks, shared, devices | |
# Registry of selectable unet options; list_unets() refreshes it in place and
# get_unet_option() searches it by label / model_name.
unet_options = []

# Option most recently applied by apply_unet(); None when no replacement is selected.
current_unet_option = None

# Unet object created from current_unet_option; consulted by the forward
# wrapper built in create_unet_forward(). None means "use the original forward".
current_unet = None

original_forward = None  # not used, only left temporarily for compatibility
def list_unets():
    """Refresh the module-level ``unet_options`` list in place.

    The options are obtained from ``script_callbacks.list_unets_callback()``.
    The existing list object is kept (only its contents are replaced), so
    references to ``unet_options`` held elsewhere stay valid.
    """
    discovered = script_callbacks.list_unets_callback()

    # Slice assignment empties the list and refills it in a single step while
    # preserving the identity of the list object.
    unet_options[:] = discovered
def get_unet_option(option=None):
    """Resolve a unet selection name to an entry of ``unet_options``.

    Args:
        option: Label of the wanted option, or one of the special values
            ``"None"`` / ``"Automatic"``. When falsy, ``shared.opts.sd_unet``
            is used instead.

    Returns:
        The first entry of ``unet_options`` whose ``label`` equals the
        resolved name, or ``None`` when the selection is ``"None"`` or no
        entry matches.
    """
    wanted = option or shared.opts.sd_unet

    if wanted == "None":
        return None

    if wanted == "Automatic":
        # Pick the option registered for the currently loaded checkpoint, if any.
        checkpoint_name = shared.sd_model.sd_checkpoint_info.model_name
        for_checkpoint = next(
            (candidate for candidate in unet_options if candidate.model_name == checkpoint_name),
            None,
        )
        wanted = "None" if for_checkpoint is None else for_checkpoint.label

    # The final lookup is always by label, including after "Automatic" resolution.
    return next((candidate for candidate in unet_options if candidate.label == wanted), None)
def apply_unet(option=None):
    """Make the unet selected by *option* the active replacement unet.

    The selection is resolved through ``get_unet_option``. If it equals the
    option that is already applied, nothing happens. Otherwise the currently
    active unet (if any) is deactivated, and a new one is created from the
    selected option, tagged with that option via its ``option`` attribute,
    and activated.

    Args:
        option: Selection name forwarded to ``get_unet_option``; ``None``
            lets that function fall back to the configured setting.

    Raises:
        Exception: Whatever ``create_unet()`` or ``activate()`` raises is
            propagated. In that case the module state is reset to "no
            replacement unet" so that the stale, already-deactivated unet is
            not used by the forward hook and a later call can retry.
    """
    global current_unet_option
    global current_unet

    new_option = get_unet_option(option)
    if new_option == current_unet_option:
        return

    if current_unet is not None:
        print(f"Deactivating unet: {current_unet.option.label}")
        current_unet.deactivate()

    current_unet_option = new_option
    if current_unet_option is None:
        current_unet = None
        return

    try:
        current_unet = current_unet_option.create_unet()
        current_unet.option = current_unet_option
        print(f"Activating unet: {current_unet.option.label}")
        current_unet.activate()
    except Exception:
        # Roll back: without this, current_unet would keep pointing at the old
        # (deactivated) or half-initialised unet while current_unet_option
        # already names the new one, and the early-return above would block
        # any retry with the same option.
        current_unet = None
        current_unet_option = None
        raise
class SdUnetOption:
    """Describes one selectable unet replacement; subclasses implement ``create_unet``."""

    model_name = None
    """name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""

    label = None
    """name of the unet in UI"""

    def create_unet(self):
        """Build and return the SdUnet object that replaces the built-in unet when making pictures.

        The base implementation always raises ``NotImplementedError``.
        """
        raise NotImplementedError()
class SdUnet(torch.nn.Module):
    """Base class for unet replacements produced by ``SdUnetOption.create_unet``."""

    def forward(self, x, timesteps, context, *args, **kwargs):
        """Run the replacement unet; the base implementation always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def activate(self):
        """Hook called by ``apply_unet`` after this unet has been created; no-op by default."""

    def deactivate(self):
        """Hook called by ``apply_unet`` before this unet is replaced; no-op by default."""
def create_unet_forward(original_forward):
    """Wrap *original_forward* so calls can be redirected to the active replacement unet.

    Args:
        original_forward: The unbound forward function to fall back to; it is
            called as ``original_forward(self, x, timesteps, context, ...)``.

    Returns:
        A function with a method-style signature ``(self, x, timesteps=None,
        context=None, *args, **kwargs)``. On each call it checks the
        module-level ``current_unet``: if one is set, its ``forward`` handles
        the call (without ``self``); otherwise *original_forward* does.
    """

    def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
        replacement = current_unet
        if replacement is None:
            return original_forward(self, x, timesteps, context, *args, **kwargs)

        return replacement.forward(x, timesteps, context, *args, **kwargs)

    return UNetModel_forward