# Uploaded by toto10 via huggingface_hub ("Upload folder using huggingface_hub", #1).
import torch
import einops
import hashlib
import numpy as np
import torch.nn as nn
from enum import Enum
from modules import devices, lowvram, shared, scripts
# Use devices.cond_cast_unet when this webui build provides it; otherwise fall
# back to an identity function so callers can apply it unconditionally.
cond_cast_unet = getattr(devices, 'cond_cast_unet', lambda tensor: tensor)
from ldm.modules.diffusionmodules.util import timestep_embedding
from ldm.modules.diffusionmodules.openaimodel import UNetModel
from ldm.modules.attention import BasicTransformerBlock
from ldm.models.diffusion.ddpm import extract_into_tensor
from modules.prompt_parser import MulticondLearnedConditioning, ComposableScheduledPromptConditioning, ScheduledPromptConditioning
# Marker values written into the first token of a prompt conditioning to flag it
# as cond (positive prompt) or uncond (negative prompt). prompt_context_is_marked()
# compares |token| against POSITIVE_MARK_TOKEN, and unmark_prompt_context() tests
# for NEGATIVE_MARK_TOKEN, so the two must be exact negatives of each other.
# NOTE(review): these names are used by the functions below but were not defined
# anywhere in this file; 1024 is believed to match upstream sd-webui-controlnet
# (it is exactly representable in fp16) — confirm.
POSITIVE_MARK_TOKEN = 1024
NEGATIVE_MARK_TOKEN = - POSITIVE_MARK_TOKEN
# Tolerance used when checking whether a token equals a marker value.
MARK_EPS = 1e-3
def prompt_context_is_marked(x):
    """Return True if the first token of ``x`` holds a cond/uncond marker.

    The token at index 0 of the second-to-last axis is compared (by absolute
    value) with POSITIVE_MARK_TOKEN; the mean absolute deviation must be below
    MARK_EPS. Both positive and negative markers therefore count as "marked".
    """
    leading_token = x[..., 0, :]
    mean_deviation = leading_token.abs().sub(POSITIVE_MARK_TOKEN).abs().mean()
    as_array = mean_deviation.detach().cpu().float().numpy()
    return float(as_array) < MARK_EPS
def mark_prompt_context(x, positive):
    """Prepend a cond/uncond marker token to every prompt conditioning in ``x``.

    Args:
        x: a list, MulticondLearnedConditioning, ComposableScheduledPromptConditioning
            or ScheduledPromptConditioning; containers are walked recursively.
            Any other value is returned unchanged.
        positive: True to mark with POSITIVE_MARK_TOKEN (cond prompt), False to
            mark with NEGATIVE_MARK_TOKEN (uncond prompt).

    Returns:
        ``x`` itself for containers (lists and the two container classes are
        updated in place), or a new ScheduledPromptConditioning whose ``cond``
        has one extra leading row filled with the marker value. Already-marked
        conditionings are returned as-is, so calling this twice is harmless.
    """
    if isinstance(x, list):
        # In-place update: callers such as process_sample ignore the return value.
        for i, item in enumerate(x):
            x[i] = mark_prompt_context(item, positive)
        return x
    if isinstance(x, MulticondLearnedConditioning):
        x.batch = mark_prompt_context(x.batch, positive)
        return x
    if isinstance(x, ComposableScheduledPromptConditioning):
        x.schedules = mark_prompt_context(x.schedules, positive)
        return x
    if isinstance(x, ScheduledPromptConditioning):
        cond = x.cond
        if prompt_context_is_marked(cond):
            return x
        mark = POSITIVE_MARK_TOKEN if positive else NEGATIVE_MARK_TOKEN
        # A row shaped like cond[:1], filled with the marker, goes in front of cond.
        cond = torch.cat([torch.zeros_like(cond)[:1] + mark, cond], dim=0)
        return ScheduledPromptConditioning(end_at_step=x.end_at_step, cond=cond)
    return x
# Keep True to silence the diagnostic printed by unmark_prompt_context() when it
# receives a context that was never marked as cond/uncond; set to False to see it.
disable_controlnet_prompt_warning = True


def unmark_prompt_context(x):
    """Separate the cond/uncond marker from a batched prompt context.

    Returns a tuple ``(mark_batch, uc_indices, context)``:
      * mark_batch: shape (B, 1, 1, 1), 1.0 for cond rows and 0.0 for uncond rows;
      * uc_indices: list of batch indices whose marker equals NEGATIVE_MARK_TOKEN;
      * context: ``x`` with the marker token (index 1 of axis 1) removed.

    If ``x`` is not marked, every row is treated as cond, ``uc_indices`` is
    empty and ``x`` is returned unchanged as the context.
    """
    if prompt_context_is_marked(x):
        marker_tokens = x[:, 0, :]
        stripped = x[:, 1:, :]
        distance_to_negative = (marker_tokens - NEGATIVE_MARK_TOKEN).abs().mean(dim=1)
        cond_flags = distance_to_negative.gt(MARK_EPS).float()
        cond_batch = cond_flags[:, None, None, None].to(x.dtype).to(x.device)
        flag_values = cond_flags.detach().cpu().numpy().tolist()
        negative_rows = [row for row, flag in enumerate(flag_values) if flag < 0.5]
        return cond_batch, negative_rows, stripped

    # Unmarked context: ControlNet cannot tell positive from negative prompts.
    # Extension authors should call mark_prompt_context(XXX, positive=True) on
    # positive prompts and mark_prompt_context(XXX, positive=False) on negative
    # prompts (XXX being any supported conditioning type or a list of them).
    if not disable_controlnet_prompt_warning:
        print('ControlNet Error: Failed to detect whether an instance is cond or uncond!')
        print('ControlNet Error: This is mainly because other extension(s) blocked A1111\'s \"process.sample()\" and deleted ControlNet\'s sample function.')
        print('ControlNet Error: ControlNet will shift to a backup backend but the results will be worse than expectation.')
        print('Solution (For extension developers): Take a look at ControlNet\' '
              'UnetHook.hook.process_sample and manually call mark_prompt_context to mark cond/uncond prompts.')
    all_cond = torch.ones(size=(x.shape[0], 1, 1, 1), dtype=x.dtype, device=x.device)
    return all_cond, [], x
class ControlModelType(Enum):
    """
    The type of Control Models (supported or not).
    """

    # Each value is "<model family>, <author>".
    ControlNet = "ControlNet, Lvmin Zhang"
    T2I_Adapter = "T2I_Adapter, Chong Mou"
    T2I_StyleAdapter = "T2I_StyleAdapter, Chong Mou"
    T2I_CoAdapter = "T2I_CoAdapter, Chong Mou"
    MasaCtrl = "MasaCtrl, Mingdeng Cao"
    GLIGEN = "GLIGEN, Yuheng Li"
    AttentionInjection = "AttentionInjection, Lvmin Zhang"  # A simple attention injection written by Lvmin
    StableSR = "StableSR, Jianyi Wang"
    PromptDiffusion = "PromptDiffusion, Zhendong Wang"
    ControlLoRA = "ControlLoRA, Wu Hecong"
# Written by Lvmin
class AutoMachine(Enum):
    """
    Lvmin's algorithm for Attention/AdaIn AutoMachine States.
    """

    # Read: patched blocks consume previously recorded features.
    Read = "Read"
    # Write: patched blocks record features (during the reference pass).
    Write = "Write"
class TorchHijackForUnet:
    """
    This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
    this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
    """

    def __getattr__(self, item):
        # Only called for names NOT defined on this class (``cat`` is a real
        # method, so it never reaches here); everything else is forwarded to torch.
        if hasattr(torch, item):
            return getattr(torch, item)
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))

    def cat(self, tensors, *args, **kwargs):
        """Like ``torch.cat``; for exactly two tensors, first nearest-resize the
        first one to the spatial size (last two dims) of the second."""
        if len(tensors) == 2:
            a, b = tensors
            if a.shape[-2:] != b.shape[-2:]:
                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
            tensors = (a, b)
        return torch.cat(tensors, *args, **kwargs)


th = TorchHijackForUnet()
class ControlParams:
    """Settings for one control unit plus caches derived from its hint image.

    Assigning ``hint_cond`` (through the property setter) clears the cached
    ``used_hint_cond`` / ``used_hint_cond_latent`` / ``used_hint_inpaint_hijack``
    so they are rebuilt on the next U-Net call.
    """

    def __init__(
        self,
        control_model,
        preprocessor,
        hint_cond,
        weight,
        guidance_stopped,
        start_guidance_percent,
        stop_guidance_percent,
        advanced_weighting,
        control_model_type,
        hr_hint_cond,
        global_average_pooling,
        soft_injection,
        cfg_injection,
        **kwargs  # To avoid errors
    ):
        self.control_model = control_model
        self.preprocessor = preprocessor
        self._hint_cond = hint_cond
        self.weight = weight
        self.guidance_stopped = guidance_stopped
        self.start_guidance_percent = start_guidance_percent
        self.stop_guidance_percent = stop_guidance_percent
        self.advanced_weighting = advanced_weighting
        self.control_model_type = control_model_type
        self.global_average_pooling = global_average_pooling
        self.hr_hint_cond = hr_hint_cond
        # Caches derived from the active hint; reset by the hint_cond setter.
        self.used_hint_cond = None
        self.used_hint_cond_latent = None
        self.used_hint_inpaint_hijack = None
        self.soft_injection = soft_injection
        self.cfg_injection = cfg_injection

    @property
    def hint_cond(self):
        return self._hint_cond

    # fix for all the extensions that modify hint_cond,
    # by forcing used_hint_cond to update on the next timestep
    # hr_hint_cond can stay the same, since most extensions dont modify the hires pass
    # but if they do, it will cause problems
    @hint_cond.setter
    def hint_cond(self, new_hint_cond):
        self._hint_cond = new_hint_cond
        self.used_hint_cond = None
        self.used_hint_cond_latent = None
        self.used_hint_inpaint_hijack = None
def aligned_adding(base, x, require_channel_alignment):
    """Return ``base + x`` after making ``x`` shape-compatible with ``base``.

    Args:
        base: feature tensor; its last two dims give the target spatial size.
        x: a float (0.0 means "nothing to add") or a tensor to add to ``base``.
        require_channel_alignment: if True, ``x`` is zero-padded along dim 1 to
            ``base``'s channel count (x must not have more channels than base).

    A spatial mismatch is fixed by nearest-neighbour resizing ``x`` (with a
    warning). Resizing happens before channel padding so that padding always
    copies into a buffer with the same spatial size as ``x``.
    """
    if isinstance(x, float):
        if x == 0.0:
            return base
        return base + x

    # resize to sample resolution
    base_h, base_w = base.shape[-2:]
    xh, xw = x.shape[-2:]
    if base_h != xh or base_w != xw:
        print('[Warning] ControlNet finds unexpected mis-alignment in tensor shape.')
        x = torch.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest")

    if require_channel_alignment:
        zeros = torch.zeros_like(base)
        zeros[:, :x.shape[1], ...] = x
        x = zeros

    return base + x
# DFS Search for Torch.nn.Module, Written by Lvmin
def torch_dfs(model: torch.nn.Module):
    """Return ``model`` and all of its descendants in depth-first pre-order."""
    visited = []
    pending = [model]
    while pending:
        node = pending.pop()
        visited.append(node)
        # Push children reversed so the first child is popped (visited) first.
        pending.extend(reversed(list(node.children())))
    return visited
def predict_start_from_noise(ldm, x_t, t, noise):
    """Estimate x_0 from x_t and a noise prediction.

    Computes ``sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * noise``
    using ``ldm``'s schedule buffers broadcast to ``x_t``'s shape.
    """
    scale_xt = extract_into_tensor(ldm.sqrt_recip_alphas_cumprod, t, x_t.shape)
    scale_noise = extract_into_tensor(ldm.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    return scale_xt * x_t - scale_noise * noise
def predict_noise_from_start(ldm, x_t, t, x0):
    """Recover the noise term from x_t and an x_0 estimate.

    Inverse of predict_start_from_noise:
    ``(sqrt_recip_alphas_cumprod[t] * x_t - x0) / sqrt_recipm1_alphas_cumprod[t]``.
    """
    scale_xt = extract_into_tensor(ldm.sqrt_recip_alphas_cumprod, t, x_t.shape)
    scale_noise = extract_into_tensor(ldm.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    return (scale_xt * x_t - x0) / scale_noise
def blur(x, k):
    """Box-blur a 4-D tensor with a (2k+1)x(2k+1) window; output keeps x's size.

    Borders are handled by replicate padding before the stride-1 average pool.
    """
    window = 2 * k + 1
    padded = torch.nn.functional.pad(x, (k, k, k, k), mode='replicate')
    return torch.nn.functional.avg_pool2d(padded, (window, window), stride=(1, 1))
class TorchCache:
    """Dict cache keyed by a digest of a tensor's shape, dtype and values.

    Values are scaled by 1000 and truncated to int32 before hashing, so keys
    that differ only below ~1e-3 share an entry. The cache is unbounded.
    """

    def __init__(self):
        self.cache = {}

    def hash(self, key):
        """Return a SHA-1 hex digest identifying ``key``."""
        # Convert on the torch side: numpy has no bfloat16, so calling .numpy()
        # directly would fail for bf16 keys.
        v = key.detach().cpu().float().numpy()
        v = np.ascontiguousarray((v * 1000.0).astype(np.int32))
        sha = hashlib.sha1()
        # Mix in shape and dtype: identical value bytes with a different layout
        # (e.g. all-zero 1x4 vs 2x2 tensors) must not map to the same entry.
        sha.update(repr((tuple(key.shape), str(key.dtype))).encode('utf-8'))
        sha.update(v)
        return sha.hexdigest()

    def get(self, key):
        """Return the cached value for ``key``, or None if absent."""
        return self.cache.get(self.hash(key), None)

    def set(self, key, value):
        """Store ``value`` under ``key``'s digest."""
        self.cache[self.hash(key)] = value
class UnetHook(nn.Module):
    """Injects ControlNet-family guidance into an LDM U-Net.

    ``hook()`` wraps ``process.sample`` (to mark cond/uncond prompts), replaces
    the U-Net ``forward``, and patches the inner forward of every
    ``BasicTransformerBlock`` plus selected input/middle/output blocks so that
    reference-only / AdaIN control can record (Write) and replay (Read)
    features. ``restore()`` puts the original U-Net ``forward`` back.
    """

    def __init__(self, lowvram=False) -> None:
        # nn.Module bookkeeping must exist before hook() assigns Module-valued
        # attributes (self.model / self.sd_ldm); otherwise that assignment
        # raises "cannot assign module before Module.__init__() call".
        super().__init__()
        # When True, control models are moved back to the CPU after use.
        self.lowvram = lowvram
        self.model = None
        self.sd_ldm = None
        self.control_params = None
        # State for reference-only attention / AdaIN injection.
        self.attention_auto_machine = AutoMachine.Read
        self.attention_auto_machine_weight = 1.0
        self.gn_auto_machine = AutoMachine.Read
        self.gn_auto_machine_weight = 1.0
        self.current_style_fidelity = 0.0
        # Batch indices of uncond (negative-prompt) rows in the current step.
        self.current_uc_indices = None

    def guidance_schedule_handler(self, x):
        """Set each unit's ``guidance_stopped`` from its start/stop window.

        ``x`` must expose ``sampling_step`` and ``total_sampling_steps``;
        presumably it is webui's CFG-denoiser callback params — TODO confirm.
        """
        current_sampling_percent = x.sampling_step / x.total_sampling_steps
        for param in self.control_params:
            param.guidance_stopped = (
                current_sampling_percent < param.start_guidance_percent
                or current_sampling_percent > param.stop_guidance_percent
            )

    def hook(self, model, sd_ldm, control_params, process):
        """Install ControlNet into ``model`` for one generation.

        Args:
            model: the U-Net whose ``forward`` is replaced (bound as UNetModel).
            sd_ldm: latent-diffusion wrapper used for VAE encoding and ``q_sample``.
            control_params: list of ControlParams, one per active unit.
            process: webui processing object; its ``sample`` is wrapped so the
                cond/uncond prompts are marked before sampling.
        """
        self.model = model
        self.sd_ldm = sd_ldm
        self.control_params = control_params

        outer = self

        def process_sample(*args, **kwargs):
            # ControlNet must know whether a prompt is the conditional (positive)
            # or unconditional (negative) prompt. Other code can call this
            # module's `mark_prompt_context` to mark the prompts ControlNet sees.
            # If XXX is a MulticondLearnedConditioning, a ComposableScheduledPromptConditioning,
            # a ScheduledPromptConditioning, or a list of these:
            #   - for a positive prompt call mark_prompt_context(XXX, positive=True)
            #   - for a negative prompt call mark_prompt_context(XXX, positive=False)
            # Once prompts are marked, ControlNet can tell cond from uncond and the
            # mismatch errors disappear.
            mark_prompt_context(kwargs.get('conditioning', []), positive=True)
            mark_prompt_context(kwargs.get('unconditional_conditioning', []), positive=False)
            mark_prompt_context(getattr(process, 'hr_c', []), positive=True)
            mark_prompt_context(getattr(process, 'hr_uc', []), positive=False)
            return process.sample_before_CN_hack(*args, **kwargs)

        def vae_forward(x, batch_size, mask=None):
            """Encode a pixel image to a U-Net-dtype latent repeated to ``batch_size``.

            Only the first 3 channels are used; they are mapped with x * 2 - 1
            (so presumably expected in [0, 1]). Pixels where ``mask`` is 1 are
            zeroed before encoding. Encodings are cached in ``outer.vae_cache``.
            """
            try:
                if x.shape[1] > 3:
                    x = x[:, 0:3, :, :]
                x = x * 2.0 - 1.0
                if mask is not None:
                    x = x * (1.0 - mask)
                x = x.type(devices.dtype_vae)
                vae_output = outer.vae_cache.get(x)
                if vae_output is None:
                    with devices.autocast():
                        vae_output = outer.sd_ldm.encode_first_stage(x)
                        vae_output = outer.sd_ldm.get_first_stage_encoding(vae_output)
                    outer.vae_cache.set(x, vae_output)
                    print(f'ControlNet used {str(devices.dtype_vae)} VAE to encode {vae_output.shape}.')
                latent = vae_output
                if latent.shape[0] != batch_size:
                    latent = torch.cat([latent.clone() for _ in range(batch_size)], dim=0)
                latent = latent.type(devices.dtype_unet)
                return latent
            except Exception as e:
                raise ValueError('ControlNet failed to use VAE. Please try to add `--no-half-vae`, `--no-half` and remove `--precision full` in launch cmd.') from e

        def forward(self, x, timesteps=None, context=None, **kwargs):
            """Replacement U-Net forward (``self`` here is the U-Net, not the hook).

            Runs every active control model, adds their residuals into the
            U-Net, and applies colorfix / inpaint-only post-processing to ``h``.
            """
            # 13 ControlNet residuals (consumed by the middle block and the decoder
            # skip connections) and 4 T2I-Adapter features (added after every third
            # input block); 0.0 means "nothing to add".
            total_controlnet_embedding = [0.0] * 13
            total_t2i_adapter_embedding = [0.0] * 4
            require_inpaint_hijack = False
            is_in_high_res_fix = False
            batch_size = int(x.shape[0])

            # Handle cond-uncond marker
            cond_mark, outer.current_uc_indices, context = unmark_prompt_context(context)

            # High-res fix: pick the hint whose pixel size is closest to x * 8.
            for param in outer.control_params:
                # select which hint_cond to use
                if param.used_hint_cond is None:
                    param.used_hint_cond = param.hint_cond
                    param.used_hint_cond_latent = None
                    param.used_hint_inpaint_hijack = None

                # has high-res fix
                if param.hr_hint_cond is not None and x.ndim == 4 and param.hint_cond.ndim == 4 and param.hr_hint_cond.ndim == 4:
                    _, _, h_lr, w_lr = param.hint_cond.shape
                    _, _, h_hr, w_hr = param.hr_hint_cond.shape
                    _, _, h, w = x.shape
                    h, w = h * 8, w * 8
                    if abs(h - h_lr) < abs(h - h_hr):
                        is_in_high_res_fix = False
                        if param.used_hint_cond is not param.hint_cond:
                            param.used_hint_cond = param.hint_cond
                            param.used_hint_cond_latent = None
                            param.used_hint_inpaint_hijack = None
                    else:
                        is_in_high_res_fix = True
                        if param.used_hint_cond is not param.hr_hint_cond:
                            param.used_hint_cond = param.hr_hint_cond
                            param.used_hint_cond_latent = None
                            param.used_hint_inpaint_hijack = None

            # Convert control image to latent (only for units that need it).
            for param in outer.control_params:
                if param.used_hint_cond_latent is not None:
                    continue
                if param.control_model_type not in [ControlModelType.AttentionInjection] \
                        and 'colorfix' not in param.preprocessor['name'] \
                        and 'inpaint_only' not in param.preprocessor['name']:
                    continue
                param.used_hint_cond_latent = vae_forward(param.used_hint_cond, batch_size=batch_size)

            # handle prompt token control
            for param in outer.control_params:
                if param.guidance_stopped:
                    continue
                if param.control_model_type not in [ControlModelType.T2I_StyleAdapter]:
                    continue
                control = param.control_model(x=x, hint=param.used_hint_cond, timesteps=timesteps, context=context)
                control = torch.cat([control.clone() for _ in range(batch_size)], dim=0)
                control *= param.weight
                control *= cond_mark[:, :, :, 0]
                context = torch.cat([context, control.clone()], dim=1)

            # handle ControlNet / T2I_Adapter
            for param in outer.control_params:
                if param.guidance_stopped:
                    continue
                if param.control_model_type not in [ControlModelType.ControlNet, ControlModelType.T2I_Adapter]:
                    continue

                # inpaint model workaround
                x_in = x
                control_model = param.control_model.control_model

                if param.control_model_type == ControlModelType.ControlNet:
                    if x.shape[1] != control_model.input_blocks[0][0].in_channels and x.shape[1] == 9:
                        # inpaint_model: 4 data + 4 downscaled image + 1 mask
                        x_in = x[:, :4, ...]
                        require_inpaint_hijack = True

                assert param.used_hint_cond is not None, "Controlnet is enabled but no input image is given"

                hint = param.used_hint_cond

                # ControlNet inpaint protocol: 4-channel hint = RGB + mask;
                # masked pixels are set to -1.
                if hint.shape[1] == 4:
                    c = hint[:, 0:3, :, :]
                    m = hint[:, 3:4, :, :]
                    m = (m > 0.5).float()
                    hint = c * (1 - m) - m

                control = param.control_model(x=x_in, hint=hint, timesteps=timesteps, context=context)
                control_scales = ([param.weight] * 13)

                if outer.lowvram:
                    param.control_model.to("cpu")

                if param.cfg_injection or param.global_average_pooling:
                    if param.control_model_type == ControlModelType.T2I_Adapter:
                        control = [torch.cat([c.clone() for _ in range(batch_size)], dim=0) for c in control]
                    # Zero the residuals for uncond rows.
                    control = [c * cond_mark for c in control]

                if param.soft_injection or is_in_high_res_fix:
                    # important! use the soft weights with high-res fix can significantly reduce artifacts.
                    if param.control_model_type == ControlModelType.T2I_Adapter:
                        control_scales = [param.weight * s for s in (0.25, 0.62, 0.825, 1.0)]
                    elif param.control_model_type == ControlModelType.ControlNet:
                        control_scales = [param.weight * (0.825 ** float(12 - i)) for i in range(13)]

                if param.advanced_weighting is not None:
                    control_scales = param.advanced_weighting

                control = [c * scale for c, scale in zip(control, control_scales)]
                if param.global_average_pooling:
                    control = [torch.mean(c, dim=(2, 3), keepdim=True) for c in control]

                for idx, item in enumerate(control):
                    target = None
                    if param.control_model_type == ControlModelType.ControlNet:
                        target = total_controlnet_embedding
                    if param.control_model_type == ControlModelType.T2I_Adapter:
                        target = total_t2i_adapter_embedding
                    if target is not None:
                        target[idx] = item + target[idx]

            # Clear attention and AdaIn cache
            for module in outer.attn_module_list:
                module.bank = []
                module.style_cfgs = []
            for module in outer.gn_module_list:
                module.mean_bank = []
                module.var_bank = []
                module.style_cfgs = []

            # Handle attention and AdaIn control
            for param in outer.control_params:
                if param.guidance_stopped:
                    continue
                if param.used_hint_cond_latent is None:
                    continue
                if param.control_model_type not in [ControlModelType.AttentionInjection]:
                    continue

                ref_xt = outer.sd_ldm.q_sample(param.used_hint_cond_latent, torch.round(timesteps.float()).long())

                # Inpaint Hijack: 9-channel input = noised latent + empty mask + reference latent.
                if x.shape[1] == 9:
                    ref_xt = torch.cat([
                        ref_xt,
                        torch.zeros_like(ref_xt)[:, 0:1, :, :],
                        param.used_hint_cond_latent,
                    ], dim=1)

                outer.current_style_fidelity = float(param.preprocessor['threshold_a'])
                outer.current_style_fidelity = max(0.0, min(1.0, outer.current_style_fidelity))

                if param.cfg_injection:
                    outer.current_style_fidelity = 1.0
                elif param.soft_injection or is_in_high_res_fix:
                    outer.current_style_fidelity = 0.0

                control_name = param.preprocessor['name']

                if control_name in ['reference_only', 'reference_adain+attn']:
                    outer.attention_auto_machine = AutoMachine.Write
                    outer.attention_auto_machine_weight = param.weight

                if control_name in ['reference_adain', 'reference_adain+attn']:
                    outer.gn_auto_machine = AutoMachine.Write
                    outer.gn_auto_machine_weight = param.weight

                # Reference pass: run the unpatched U-Net forward on the noised
                # reference so the patched blocks record features in Write mode.
                # NOTE(review): this call was lost from the damaged source and is
                # reconstructed; without it the Write state above is never used.
                outer.original_forward(
                    x=ref_xt.to(devices.dtype_unet),
                    timesteps=timesteps.to(devices.dtype_unet),
                    context=context.to(devices.dtype_unet),
                )

                outer.attention_auto_machine = AutoMachine.Read
                outer.gn_auto_machine = AutoMachine.Read

            # Replace x_t to support inpaint models
            for param in outer.control_params:
                if param.used_hint_cond.shape[1] != 4:
                    continue
                if x.shape[1] != 9:
                    continue
                if param.used_hint_inpaint_hijack is None:
                    mask_pixel = param.used_hint_cond[:, 3:4, :, :]
                    image_pixel = param.used_hint_cond[:, 0:3, :, :]
                    mask_pixel = (mask_pixel > 0.5).to(mask_pixel.dtype)
                    masked_latent = vae_forward(image_pixel, batch_size, mask=mask_pixel)
                    mask_latent = torch.nn.functional.max_pool2d(mask_pixel, (8, 8))
                    if mask_latent.shape[0] != batch_size:
                        mask_latent = torch.cat([mask_latent.clone() for _ in range(batch_size)], dim=0)
                    param.used_hint_inpaint_hijack = torch.cat(
                        [mask_latent, masked_latent], dim=1).to(dtype=x.dtype, device=x.device)
                x = torch.cat([x[:, :4, :, :], param.used_hint_inpaint_hijack], dim=1)

            # A1111 fix for medvram.
            if shared.cmd_opts.medvram:
                try:
                    # Trigger the register_forward_pre_hook.
                    # NOTE(review): reconstructed; the call itself is expected to
                    # fail after the pre-hook has run, hence the swallow.
                    outer.sd_ldm.model()
                except Exception:
                    pass

            # U-Net Encoder
            hs = []
            with th.no_grad():
                t_emb = cond_cast_unet(timestep_embedding(timesteps, self.model_channels, repeat_only=False))
                emb = self.time_embed(t_emb)
                h = x.type(self.dtype)
                for i, module in enumerate(self.input_blocks):
                    h = module(h, emb, context)

                    if (i + 1) % 3 == 0:
                        h = aligned_adding(h, total_t2i_adapter_embedding.pop(0), require_inpaint_hijack)

                    hs.append(h)
                h = self.middle_block(h, emb, context)

            # U-Net Middle Block
            h = aligned_adding(h, total_controlnet_embedding.pop(), require_inpaint_hijack)

            # U-Net Decoder (th.cat resizes mismatched skip connections)
            for module in self.output_blocks:
                h = th.cat([h, aligned_adding(hs.pop(), total_controlnet_embedding.pop(), require_inpaint_hijack)], dim=1)
                h = module(h, emb, context)

            # U-Net Output
            h = h.type(x.dtype)
            h = self.out(h)

            # Post-processing for color fix
            for param in outer.control_params:
                if param.used_hint_cond_latent is None:
                    continue
                if 'colorfix' not in param.preprocessor['name']:
                    continue

                k = int(param.preprocessor['threshold_a'])
                if is_in_high_res_fix:
                    k *= 2

                # Inpaint hijack
                xt = x[:, :4, :, :]

                x0_origin = param.used_hint_cond_latent
                t = torch.round(timesteps.float()).long()
                x0_prd = predict_start_from_noise(outer.sd_ldm, xt, t, h)
                # Keep predicted high frequencies, take low frequencies from the reference.
                x0 = x0_prd - blur(x0_prd, k) + blur(x0_origin, k)

                if '+sharp' in param.preprocessor['name']:
                    detail_weight = float(param.preprocessor['threshold_b']) * 0.01
                    neg = detail_weight * blur(x0, k) + (1 - detail_weight) * x0
                    x0 = cond_mark * x0 + (1 - cond_mark) * neg

                eps_prd = predict_noise_from_start(outer.sd_ldm, xt, t, x0)

                w = max(0.0, min(1.0, float(param.weight)))
                h = eps_prd * w + h * (1 - w)

            # Post-processing for restore (inpaint_only)
            for param in outer.control_params:
                if param.used_hint_cond_latent is None:
                    continue
                if 'inpaint_only' not in param.preprocessor['name']:
                    continue
                if param.used_hint_cond.shape[1] != 4:
                    continue

                # Inpaint hijack
                xt = x[:, :4, :, :]

                mask = param.used_hint_cond[:, 3:4, :, :]
                mask = torch.nn.functional.max_pool2d(mask, (10, 10), stride=(8, 8), padding=1)

                x0_origin = param.used_hint_cond_latent
                t = torch.round(timesteps.float()).long()
                x0_prd = predict_start_from_noise(outer.sd_ldm, xt, t, h)
                x0 = x0_prd * mask + x0_origin * (1 - mask)
                eps_prd = predict_noise_from_start(outer.sd_ldm, xt, t, x0)

                w = max(0.0, min(1.0, float(param.weight)))
                h = eps_prd * w + h * (1 - w)

            return h

        def forward_webui(*args, **kwargs):
            # webui will handle other components
            try:
                if shared.cmd_opts.lowvram:
                    # NOTE(review): reconstructed; frees VRAM held by webui's
                    # lowvram-managed modules before running the control models.
                    lowvram.send_everything_to_cpu()
                return forward(*args, **kwargs)
            finally:
                # ``self`` here is the UnetHook (closure over hook()'s self).
                if self.lowvram:
                    for param in self.control_params:
                        if isinstance(param.control_model, torch.nn.Module):
                            param.control_model.to("cpu")

        def hacked_basic_transformer_inner_forward(self, x, context=None):
            """Patched BasicTransformerBlock._forward with reference-attention banks."""
            x_norm1 = self.norm1(x)
            self_attn1 = None
            if self.disable_self_attn:
                # Do not use self-attention
                self_attn1 = self.attn1(x_norm1, context=context)
            else:
                # Use self-attention
                self_attention_context = x_norm1
                if outer.attention_auto_machine == AutoMachine.Write:
                    if outer.attention_auto_machine_weight > self.attn_weight:
                        self.bank.append(self_attention_context.detach().clone())
                        self.style_cfgs.append(outer.current_style_fidelity)
                if outer.attention_auto_machine == AutoMachine.Read:
                    if len(self.bank) > 0:
                        style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs))
                        # Attend over own tokens plus the recorded reference tokens.
                        self_attn1_uc = self.attn1(x_norm1, context=torch.cat([self_attention_context] + self.bank, dim=1))
                        self_attn1_c = self_attn1_uc.clone()
                        if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5:
                            # Uncond rows use plain self-attention (no reference).
                            self_attn1_c[outer.current_uc_indices] = self.attn1(
                                x_norm1[outer.current_uc_indices],
                                context=self_attention_context[outer.current_uc_indices])
                        self_attn1 = style_cfg * self_attn1_c + (1.0 - style_cfg) * self_attn1_uc
                    self.bank = []
                    self.style_cfgs = []
                if self_attn1 is None:
                    self_attn1 = self.attn1(x_norm1, context=self_attention_context)

            x = self_attn1.to(x.dtype) + x
            x = self.attn2(self.norm2(x), context=context) + x
            x = self.ff(self.norm3(x)) + x
            return x

        def hacked_group_norm_forward(self, *args, **kwargs):
            """Patched block forward that records/replays per-channel mean/var (AdaIN)."""
            eps = 1e-6
            x = self.original_forward(*args, **kwargs)
            y = None
            if outer.gn_auto_machine == AutoMachine.Write:
                if outer.gn_auto_machine_weight > self.gn_weight:
                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
                    self.mean_bank.append(mean)
                    self.var_bank.append(var)
                    self.style_cfgs.append(outer.current_style_fidelity)
            if outer.gn_auto_machine == AutoMachine.Read:
                if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
                    style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs))
                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
                    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
                    mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
                    var_acc = sum(self.var_bank) / float(len(self.var_bank))
                    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
                    # Re-normalise x to the recorded reference statistics.
                    y_uc = (((x - mean) / std) * std_acc) + mean_acc
                    y_c = y_uc.clone()
                    if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5:
                        y_c[outer.current_uc_indices] = x.to(y_c.dtype)[outer.current_uc_indices]
                    y = style_cfg * y_c + (1.0 - style_cfg) * y_uc
                self.mean_bank = []
                self.var_bank = []
                self.style_cfgs = []
            if y is None:
                y = x
            return y.to(x.dtype)

        if getattr(process, 'sample_before_CN_hack', None) is None:
            process.sample_before_CN_hack = process.sample
        process.sample = process_sample

        model._original_forward = model.forward
        outer.original_forward = model.forward
        model.forward = forward_webui.__get__(model, UNetModel)

        outer.vae_cache = TorchCache()

        all_modules = torch_dfs(model)

        # Attention blocks, widest first; attn_weight grows from 0 towards 1.
        attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock)]
        attn_modules = sorted(attn_modules, key=lambda m: - m.norm1.normalized_shape[0])

        for i, module in enumerate(attn_modules):
            if getattr(module, '_original_inner_forward', None) is None:
                module._original_inner_forward = module._forward
            module._forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
            module.bank = []
            module.style_cfgs = []
            module.attn_weight = float(i) / float(len(attn_modules))

        gn_modules = [model.middle_block]
        model.middle_block.gn_weight = 0

        input_block_indices = [4, 5, 7, 8, 10, 11]
        for w, i in enumerate(input_block_indices):
            module = model.input_blocks[i]
            module.gn_weight = 1.0 - float(w) / float(len(input_block_indices))
            gn_modules.append(module)

        output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7]
        for w, i in enumerate(output_block_indices):
            module = model.output_blocks[i]
            module.gn_weight = float(w) / float(len(output_block_indices))
            gn_modules.append(module)

        for module in gn_modules:
            if getattr(module, 'original_forward', None) is None:
                module.original_forward = module.forward
            module.forward = hacked_group_norm_forward.__get__(module, torch.nn.Module)
            module.mean_bank = []
            module.var_bank = []
            module.style_cfgs = []
            module.gn_weight *= 2

        outer.attn_module_list = attn_modules
        outer.gn_module_list = gn_modules

        # NOTE(review): reconstructed; guidance_schedule_handler is otherwise never
        # attached, so start/stop guidance would have no effect. Confirm it is not
        # already registered elsewhere.
        scripts.script_callbacks.on_cfg_denoiser(self.guidance_schedule_handler)

    def restore(self, model):
        """Undo hook() for ``model``'s forward.

        Only the U-Net ``forward`` is restored; patched transformer/group-norm
        blocks and ``process.sample`` are left in place.
        """
        # NOTE(review): reconstructed counterpart of the on_cfg_denoiser registration.
        scripts.script_callbacks.remove_current_script_callbacks()
        if hasattr(self, "control_params"):
            del self.control_params

        if not hasattr(model, "_original_forward"):
            # no such handle, ignore
            return

        model.forward = model._original_forward
        del model._original_forward