# ehristoforu's picture
# Upload folder using huggingface_hub
# 0163a2c verified
import gc
import os
import torch
from einops import rearrange
from modules import hashes, shared, sd_models, devices
from modules.devices import cpu, device, torch_gc
from motion_module import MotionWrapper, MotionModuleType
from scripts.animatediff_logger import logger_animatediff as logger
class AnimateDiffMM:
    """Load a motion module and splice it into / remove it from an SD UNet.

    ``mm_injected`` is a class-level flag, so the "injected" state is shared
    by every instance of this class.
    """

    mm_injected = False

    # UNet input-block indices that receive a motion module, in injection order.
    _INPUT_BLOCK_INDICES = (1, 2, 4, 5, 7, 8, 10, 11)

    def __init__(self):
        self.mm: MotionWrapper = None
        self.script_dir = None
        # alphas_cumprod of the SD model as it was before injection.
        self.prev_alpha_cumprod = None
        # Original GroupNorm32.forward, kept only while the GN hack is active.
        self.gn32_original_forward = None

    def set_script_dir(self, script_dir):
        """Remember the directory used to build the default model path."""
        self.script_dir = script_dir

    def get_model_dir(self):
        """Return the directory that motion module files are read from.

        Uses the ``animatediff_model_path`` option when it is set to a
        non-empty value, otherwise ``<script_dir>/model``.
        """
        # The fallback is built lazily: computing it eagerly (as a dict.get
        # default) raised TypeError whenever script_dir was still None, even
        # though a configured path made the fallback unnecessary.
        configured = shared.opts.data.get("animatediff_model_path")
        return configured or os.path.join(self.script_dir, "model")

    def _groupnorm32_cls(self):
        """Import and return the GroupNorm32 class matching the loaded module."""
        if self.mm.is_hotshot:
            from sgm.modules.diffusionmodules.util import GroupNorm32
        else:
            from ldm.modules.diffusionmodules.util import GroupNorm32
        return GroupNorm32

    def _load(self, model_name):
        """Ensure ``self.mm`` holds ``model_name`` and is ready on ``device``.

        The checkpoint is read only when no module is loaded or a different
        one is requested. Afterwards the module is moved to ``device``, put in
        eval mode and, unless ``--no-half`` is set, converted to half
        precision (plus an fp8 cast of Conv2d/Linear layers when
        ``devices.fp8`` is truthy).

        Raises:
            RuntimeError: if the model file does not exist.
        """
        model_path = os.path.join(self.get_model_dir(), model_name)
        if not os.path.isfile(model_path):
            raise RuntimeError("Please download models manually.")

        if self.mm is None or self.mm.mm_name != model_name:
            logger.info(f"Loading motion module {model_name} from {model_path}")
            model_hash = hashes.sha256(model_path, f"AnimateDiff/{model_name}")
            mm_state_dict = sd_models.read_state_dict(model_path)
            model_type = MotionModuleType.get_mm_type(mm_state_dict)
            logger.info(f"Guessed {model_name} architecture: {model_type}")
            self.mm = MotionWrapper(model_name, model_hash, model_type)
            missed_keys = self.mm.load_state_dict(mm_state_dict)
            # Logger.warn is a deprecated alias of Logger.warning.
            logger.warning(f"Missing keys {missed_keys}")

        self.mm.to(device).eval()
        if not shared.cmd_opts.no_half:
            self.mm.half()
            if getattr(devices, "fp8", False):
                for module in self.mm.modules():
                    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                        module.to(torch.float8_e4m3fn)

    def inject(self, sd_model, model_name="mm_sd_v15.ckpt"):
        """Load ``model_name`` and insert its layers into ``sd_model``'s UNet.

        If a module is already injected it is restored first. Also swaps in
        the motion module's alpha schedule and registers its layers in
        ``sd_model.network_layer_mapping`` when that attribute exists.

        Raises:
            RuntimeError: if the model file does not exist.
            AssertionError: if the SD model and motion module disagree on
                being SDXL.
        """
        if AnimateDiffMM.mm_injected:
            logger.info("Motion module already injected. Trying to restore.")
            self.restore(sd_model)

        unet = sd_model.model.diffusion_model
        self._load(model_name)
        inject_sdxl = sd_model.is_sdxl or self.mm.is_xl
        sd_ver = "SDXL" if sd_model.is_sdxl else "SD1.5"
        # Raised explicitly (rather than via `assert`) so the check still runs
        # under `python -O`; exception type and message are unchanged.
        if sd_model.is_sdxl != self.mm.is_xl:
            raise AssertionError(
                f"Motion module incompatible with SD. You are using {sd_ver} with {self.mm.mm_type}."
            )

        if self.mm.is_v2:
            logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet middle block.")
            unet.middle_block.insert(-1, self.mm.mid_block.motion_modules[0])
        elif self.mm.enable_gn_hack():
            logger.info(f"Hacking {sd_ver} GroupNorm32 forward function.")
            GroupNorm32 = self._groupnorm32_cls()
            self.gn32_original_forward = GroupNorm32.forward
            gn32_original_forward = self.gn32_original_forward

            def groupnorm32_mm_forward(self, x):
                # Fold frames out of the batch axis, normalise, fold back.
                # NOTE(review): b=2 is hard-coded; presumably the cond/uncond
                # pair - confirm against callers.
                x = rearrange(x, "(b f) c h w -> b c f h w", b=2)
                x = gn32_original_forward(self, x)
                x = rearrange(x, "b c f h w -> (b f) c h w", b=2)
                return x

            GroupNorm32.forward = groupnorm32_mm_forward

        # Name of the per-block list that holds the layers to inject.
        layer_attr = "temporal_attentions" if self.mm.is_hotshot else "motion_modules"

        logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet input blocks.")
        for mm_idx, unet_idx in enumerate(self._INPUT_BLOCK_INDICES):
            # In the SDXL case only the first six listed input blocks are used.
            if inject_sdxl and mm_idx >= 6:
                break
            mm_idx0, mm_idx1 = divmod(mm_idx, 2)
            mm_inject = getattr(self.mm.down_blocks[mm_idx0], layer_attr)[mm_idx1]
            unet.input_blocks[unet_idx].append(mm_inject)

        logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet output blocks.")
        last_output_idx = 8 if self.mm.is_xl else 11
        for unet_idx in range(12):
            # In the SDXL case only output blocks 0-8 are used.
            if inject_sdxl and unet_idx >= 9:
                break
            mm_idx0, mm_idx1 = divmod(unet_idx, 3)
            mm_inject = getattr(self.mm.up_blocks[mm_idx0], layer_attr)[mm_idx1]
            if mm_idx1 == 2 and unet_idx != last_output_idx:
                # Keep the block's current last layer last.
                # NOTE(review): presumably that layer is an upsampler - confirm.
                unet.output_blocks[unet_idx].insert(-1, mm_inject)
            else:
                unet.output_blocks[unet_idx].append(mm_inject)

        self._set_ddim_alpha(sd_model)
        self._set_layer_mapping(sd_model)
        AnimateDiffMM.mm_injected = True
        logger.info("Injection finished.")

    def restore(self, sd_model):
        """Undo :meth:`inject`: remove the injected layers and restore state.

        Does nothing (apart from logging) when no module is injected. When
        ``sd_model.lowvram`` is truthy the motion module is also moved to CPU.
        """
        if not AnimateDiffMM.mm_injected:
            logger.info("Motion module already removed.")
            return

        inject_sdxl = sd_model.is_sdxl or self.mm.is_xl
        sd_ver = "SDXL" if sd_model.is_sdxl else "SD1.5"
        self._restore_ddim_alpha(sd_model)
        unet = sd_model.model.diffusion_model

        logger.info(f"Removing motion module from {sd_ver} UNet input blocks.")
        for unet_idx in self._INPUT_BLOCK_INDICES:
            # Stops before blocks 10 and 11, matching the six blocks that
            # inject() fills in the SDXL case.
            if inject_sdxl and unet_idx >= 9:
                break
            unet.input_blocks[unet_idx].pop(-1)

        logger.info(f"Removing motion module from {sd_ver} UNet output blocks.")
        last_output_idx = 8 if self.mm.is_xl else 11
        for unet_idx in range(12):
            if inject_sdxl and unet_idx >= 9:
                break
            if unet_idx % 3 == 2 and unet_idx != last_output_idx:
                # Counterpart of insert(-1, ...) in inject().
                unet.output_blocks[unet_idx].pop(-2)
            else:
                # Counterpart of append(...) in inject().
                unet.output_blocks[unet_idx].pop(-1)

        if self.mm.is_v2:
            logger.info(f"Removing motion module from {sd_ver} UNet middle block.")
            unet.middle_block.pop(-2)
        elif self.mm.enable_gn_hack():
            logger.info(f"Restoring {sd_ver} GroupNorm32 forward function.")
            GroupNorm32 = self._groupnorm32_cls()
            GroupNorm32.forward = self.gn32_original_forward
            self.gn32_original_forward = None

        AnimateDiffMM.mm_injected = False
        logger.info("Removal finished.")
        if sd_model.lowvram:
            self.unload()

    def _set_ddim_alpha(self, sd_model):
        """Replace ``sd_model.alphas_cumprod`` with the motion module's schedule.

        The previous tensor is kept in ``self.prev_alpha_cumprod`` so that
        :meth:`_restore_ddim_alpha` can put it back.
        """
        logger.info("Setting DDIM alpha.")
        beta_start = 0.00085
        beta_end = 0.020 if self.mm.is_adxl else 0.012
        if self.mm.is_adxl:
            # Linear in sqrt(beta), then squared.
            betas = torch.linspace(beta_start**0.5, beta_end**0.5, 1000, dtype=torch.float32, device=device) ** 2
        else:
            betas = torch.linspace(
                beta_start,
                beta_end,
                1000 if sd_model.is_sdxl else sd_model.num_timesteps,
                dtype=torch.float32,
                device=device,
            )
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.prev_alpha_cumprod = sd_model.alphas_cumprod
        sd_model.alphas_cumprod = alphas_cumprod

    def _set_layer_mapping(self, sd_model):
        """Register every motion-module layer in ``network_layer_mapping``.

        Skipped when ``sd_model`` has no ``network_layer_mapping`` attribute.
        Each layer also gets its name stored as ``network_layer_name``.
        """
        if not hasattr(sd_model, 'network_layer_mapping'):
            return
        for name, module in self.mm.named_modules():
            sd_model.network_layer_mapping[name] = module
            module.network_layer_name = name

    def _restore_ddim_alpha(self, sd_model):
        """Put back the ``alphas_cumprod`` saved by :meth:`_set_ddim_alpha`."""
        logger.info("Restoring DDIM alpha.")
        sd_model.alphas_cumprod = self.prev_alpha_cumprod
        self.prev_alpha_cumprod = None

    def unload(self):
        """Move the loaded motion module (if any) to CPU and collect garbage."""
        logger.info("Moving motion module to CPU")
        if self.mm is not None:
            self.mm.to(cpu)
        torch_gc()
        gc.collect()

    def remove(self):
        """Drop the reference to the motion module and collect garbage."""
        logger.info("Removing motion module from any memory")
        self.mm = None
        torch_gc()
        gc.collect()
# Module-level AnimateDiffMM instance, created once when this module is imported.
mm_animatediff = AnimateDiffMM()