import gc
import os

import torch
from einops import rearrange
from modules import hashes, shared, sd_models, devices
from modules.devices import cpu, device, torch_gc

from motion_module import MotionWrapper, MotionModuleType
from scripts.animatediff_logger import logger_animatediff as logger


class AnimateDiffMM:
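    """Load an AnimateDiff motion module and inject its temporal layers into the
    currently loaded Stable Diffusion UNet (and remove them again on restore)."""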

    mm_injected = False

    def __init__(self):
        self.mm: MotionWrapper = None
        self.script_dir = None
        self.prev_alpha_cumprod = None
        self.gn32_original_forward = None


    def set_script_dir(self, script_dir):
        self.script_dir = script_dir


    def get_model_dir(self):
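        """Return the motion module directory from the "animatediff_model_path"
        option, falling back to "<script_dir>/model" when the option is unset or
        empty."""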
        model_dir = shared.opts.data.get("animatediff_model_path", os.path.join(self.script_dir, "model"))
        if not model_dir:
            model_dir = os.path.join(self.script_dir, "model")
        return model_dir


    def _load(self, model_name):
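        """Load the motion module checkpoint `model_name` (skipped if the same module
        is already loaded), move it to the execution device in eval mode, and match
        the WebUI's precision settings (fp16, optionally fp8)."""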
        model_path = os.path.join(self.get_model_dir(), model_name)
        if not os.path.isfile(model_path):
            raise RuntimeError("Please download models manually.")
        if self.mm is None or self.mm.mm_name != model_name:
            logger.info(f"Loading motion module {model_name} from {model_path}")
            model_hash = hashes.sha256(model_path, f"AnimateDiff/{model_name}")
            mm_state_dict = sd_models.read_state_dict(model_path)
            model_type = MotionModuleType.get_mm_type(mm_state_dict)
            logger.info(f"Guessed {model_name} architecture: {model_type}")
            self.mm = MotionWrapper(model_name, model_hash, model_type)
            missed_keys = self.mm.load_state_dict(mm_state_dict)
            logger.warn(f"Missing keys {missed_keys}")
        self.mm.to(device).eval()
        if not shared.cmd_opts.no_half:
            self.mm.half()
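        # If the WebUI exposes devices.fp8 (newer versions presumably set it when fp8
        # weight storage is enabled), cast conv/linear weights to float8_e4m3fn as
        # well; the guarded getattr keeps older WebUI versions working.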
        if getattr(devices, "fp8", False):
            for module in self.mm.modules():
                if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                    module.to(torch.float8_e4m3fn)


    def inject(self, sd_model, model_name="mm_sd_v15.ckpt"):
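        """Inject the motion module `model_name` into `sd_model`'s UNet.

        Temporal layers are appended to (or inserted into) the UNet input, output
        and, for v2 modules, middle blocks; the DDIM alpha schedule and the network
        layer mapping are updated to match.
        """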
        if AnimateDiffMM.mm_injected:
            logger.info("Motion module already injected. Trying to restore.")
            self.restore(sd_model)

        unet = sd_model.model.diffusion_model
        self._load(model_name)
        inject_sdxl = sd_model.is_sdxl or self.mm.is_xl
        sd_ver = "SDXL" if sd_model.is_sdxl else "SD1.5"
        assert sd_model.is_sdxl == self.mm.is_xl, f"Motion module incompatible with SD. You are using {sd_ver} with {self.mm.mm_type}."
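
        # v2-style motion modules ship an extra temporal layer for the UNet middle
        # block; modules that report enable_gn_hack() instead need GroupNorm32's
        # forward patched so normalization happens per sample rather than per frame.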
        if self.mm.is_v2:
            logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet middle block.")
            unet.middle_block.insert(-1, self.mm.mid_block.motion_modules[0])
        elif self.mm.enable_gn_hack():
            logger.info(f"Hacking {sd_ver} GroupNorm32 forward function.")
            if self.mm.is_hotshot:
                from sgm.modules.diffusionmodules.util import GroupNorm32
            else:
                from ldm.modules.diffusionmodules.util import GroupNorm32
            self.gn32_original_forward = GroupNorm32.forward
            gn32_original_forward = self.gn32_original_forward
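
            # Split the fused (batch * frame) axis so that frames become an extra
            # "spatial" dimension: GroupNorm32 then normalizes each sample over all
            # of its frames instead of treating every frame as an independent batch
            # item. b=2 hard-codes a batch of two (presumably the cond/uncond pair).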
            def groupnorm32_mm_forward(self, x):
                x = rearrange(x, "(b f) c h w -> b c f h w", b=2)
                x = gn32_original_forward(self, x)
                x = rearrange(x, "b c f h w -> (b f) c h w", b=2)
                return x

            GroupNorm32.forward = groupnorm32_mm_forward
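
        # Each motion-module down block carries two temporal layers; they are appended
        # to UNet input blocks 1, 2, 4, 5, 7, 8, 10, 11, skipping the input conv and
        # the downsamplers. An SDXL UNet has only nine input blocks, so only the first
        # six targets apply, hence the early break.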
        logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet input blocks.")
        for mm_idx, unet_idx in enumerate([1, 2, 4, 5, 7, 8, 10, 11]):
            if inject_sdxl and mm_idx >= 6:
                break
            mm_idx0, mm_idx1 = mm_idx // 2, mm_idx % 2
            mm_inject = getattr(self.mm.down_blocks[mm_idx0], "temporal_attentions" if self.mm.is_hotshot else "motion_modules")[mm_idx1]
            unet.input_blocks[unet_idx].append(mm_inject)
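
        # Output blocks come in groups of three per resolution level; the last block
        # of each group (except the final output block) ends with an upsampler, so the
        # temporal layer is inserted just before it rather than appended. SDXL UNets
        # stop at output block 8.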
        logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet output blocks.")
        for unet_idx in range(12):
            if inject_sdxl and unet_idx >= 9:
                break
            mm_idx0, mm_idx1 = unet_idx // 3, unet_idx % 3
            mm_inject = getattr(self.mm.up_blocks[mm_idx0], "temporal_attentions" if self.mm.is_hotshot else "motion_modules")[mm_idx1]
            if unet_idx % 3 == 2 and unet_idx != (8 if self.mm.is_xl else 11):
                unet.output_blocks[unet_idx].insert(-1, mm_inject)
            else:
                unet.output_blocks[unet_idx].append(mm_inject)

        self._set_ddim_alpha(sd_model)
        self._set_layer_mapping(sd_model)
        AnimateDiffMM.mm_injected = True
        logger.info(f"Injection finished.")


    def restore(self, sd_model):
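        """Undo inject(): pop the temporal layers out of the UNet blocks, restore the
        original GroupNorm32 forward and alpha schedule, and move the motion module
        back to CPU in low VRAM mode."""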
        if not AnimateDiffMM.mm_injected:
            logger.info("Motion module already removed.")
            return

        inject_sdxl = sd_model.is_sdxl or self.mm.is_xl
        sd_ver = "SDXL" if sd_model.is_sdxl else "SD1.5"
        self._restore_ddim_alpha(sd_model)
        unet = sd_model.model.diffusion_model

        logger.info(f"Removing motion module from {sd_ver} UNet input blocks.")
        for unet_idx in [1, 2, 4, 5, 7, 8, 10, 11]:
            if inject_sdxl and unet_idx >= 9:
                break
            unet.input_blocks[unet_idx].pop(-1)

        logger.info(f"Removing motion module from {sd_ver} UNet output blocks.")
        for unet_idx in range(12):
            if inject_sdxl and unet_idx >= 9:
                break
            if unet_idx % 3 == 2 and unet_idx != (8 if self.mm.is_xl else 11):
                unet.output_blocks[unet_idx].pop(-2)
            else:
                unet.output_blocks[unet_idx].pop(-1)

        if self.mm.is_v2:
            logger.info(f"Removing motion module from {sd_ver} UNet middle block.")
            unet.middle_block.pop(-2)
        elif self.mm.enable_gn_hack():
            logger.info(f"Restoring {sd_ver} GroupNorm32 forward function.")
            if self.mm.is_hotshot:
                from sgm.modules.diffusionmodules.util import GroupNorm32
            else:
                from ldm.modules.diffusionmodules.util import GroupNorm32
            GroupNorm32.forward = self.gn32_original_forward
            self.gn32_original_forward = None

        AnimateDiffMM.mm_injected = False
        logger.info(f"Removal finished.")
        if sd_model.lowvram:
            self.unload()


    def _set_ddim_alpha(self, sd_model):
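        """Swap `sd_model.alphas_cumprod` for the schedule AnimateDiff expects
        (presumably the one the motion module was trained with): a linear beta
        schedule from 0.00085 to 0.012, or scaled-linear up to 0.020 for
        AnimateDiff-SDXL. The original tensor is saved for _restore_ddim_alpha."""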
        logger.info(f"Setting DDIM alpha.")
        beta_start = 0.00085
        beta_end = 0.020 if self.mm.is_adxl else 0.012
        if self.mm.is_adxl:
            betas = torch.linspace(beta_start**0.5, beta_end**0.5, 1000, dtype=torch.float32, device=device) ** 2
        else:
            betas = torch.linspace(
                beta_start,
                beta_end,
                1000 if sd_model.is_sdxl else sd_model.num_timesteps,
                dtype=torch.float32,
                device=device,
            )
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.prev_alpha_cumprod = sd_model.alphas_cumprod
        sd_model.alphas_cumprod = alphas_cumprod


    def _set_layer_mapping(self, sd_model):
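        """Register the motion module's layers in the WebUI's network_layer_mapping so
        LoRA-style extra networks can locate and patch them; a no-op on WebUI versions
        without that mapping."""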
        if hasattr(sd_model, 'network_layer_mapping'):
            for name, module in self.mm.named_modules():
                sd_model.network_layer_mapping[name] = module
                module.network_layer_name = name


    def _restore_ddim_alpha(self, sd_model):
        logger.info(f"Restoring DDIM alpha.")
        sd_model.alphas_cumprod = self.prev_alpha_cumprod
        self.prev_alpha_cumprod = None


    def unload(self):
        logger.info("Moving motion module to CPU")
        if self.mm is not None:
            self.mm.to(cpu)
        torch_gc()
        gc.collect()


    def remove(self):
        logger.info("Removing motion module from any memory")
        del self.mm
        self.mm = None
        torch_gc()
        gc.collect()
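

# Module-level singleton; the rest of the extension is expected to import and use
# this shared instance.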
mm_animatediff = AnimateDiffMM()