import gc
import os

import torch
from einops import rearrange
from modules import hashes, shared, sd_models, devices
from modules.devices import cpu, device, torch_gc

from motion_module import MotionWrapper, MotionModuleType
from scripts.animatediff_logger import logger_animatediff as logger


class AnimateDiffMM:
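    """Load the AnimateDiff motion module and inject/remove it from the WebUI UNet."""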
    mm_injected = False

    def __init__(self):
        self.mm: MotionWrapper = None  # loaded motion module, if any
        self.script_dir = None  # extension root, set via set_script_dir()
        self.prev_alpha_cumprod = None  # backup of sd_model.alphas_cumprod for restore()
        self.gn32_original_forward = None  # backup of GroupNorm32.forward for restore()


    def set_script_dir(self, script_dir):
        self.script_dir = script_dir


    def get_model_dir(self):
        model_dir = shared.opts.data.get("animatediff_model_path", os.path.join(self.script_dir, "model"))
        if not model_dir:
            model_dir = os.path.join(self.script_dir, "model")
        return model_dir


    def _load(self, model_name):
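        """Load the motion module checkpoint if needed and move it to the execution device."""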
        model_path = os.path.join(self.get_model_dir(), model_name)
        if not os.path.isfile(model_path):
            raise RuntimeError("Please download models manually.")
        if self.mm is None or self.mm.mm_name != model_name:
            logger.info(f"Loading motion module {model_name} from {model_path}")
            model_hash = hashes.sha256(model_path, f"AnimateDiff/{model_name}")
            mm_state_dict = sd_models.read_state_dict(model_path)
            model_type = MotionModuleType.get_mm_type(mm_state_dict)
            logger.info(f"Guessed {model_name} architecture: {model_type}")
            self.mm = MotionWrapper(model_name, model_hash, model_type)
            missing_keys, unexpected_keys = self.mm.load_state_dict(mm_state_dict, strict=False)
            if missing_keys or unexpected_keys:
                logger.warning(f"Missing keys {missing_keys}, unexpected keys {unexpected_keys}")
        self.mm.to(device).eval()
        if not shared.cmd_opts.no_half:
            self.mm.half()
            if getattr(devices, "fp8", False):
                for module in self.mm.modules():
                    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                        module.to(torch.float8_e4m3fn)


    def inject(self, sd_model, model_name="mm_sd_v15.ckpt"):
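        """Inject motion modules into the UNet of `sd_model` and swap in the AnimateDiff noise schedule."""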
        if AnimateDiffMM.mm_injected:
            logger.info("Motion module already injected. Trying to restore.")
            self.restore(sd_model)

        unet = sd_model.model.diffusion_model
        self._load(model_name)
        inject_sdxl = sd_model.is_sdxl or self.mm.is_xl
        sd_ver = "SDXL" if sd_model.is_sdxl else "SD1.5"
        assert sd_model.is_sdxl == self.mm.is_xl, f"Motion module is incompatible with the loaded checkpoint: you are using {sd_ver} with {self.mm.mm_type}."

        if self.mm.is_v2:
            logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet middle block.")
            unet.middle_block.insert(-1, self.mm.mid_block.motion_modules[0])
        elif self.mm.enable_gn_hack():
            logger.info(f"Hacking {sd_ver} GroupNorm32 forward function.")
            if self.mm.is_hotshot:
                from sgm.modules.diffusionmodules.util import GroupNorm32
            else:
                from ldm.modules.diffusionmodules.util import GroupNorm32
            self.gn32_original_forward = GroupNorm32.forward
            gn32_original_forward = self.gn32_original_forward

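            # GroupNorm32 normally sees frames stacked along the batch axis
            # ((b f) c h w), so its statistics are computed per frame. The hack
            # below reshapes to (b c f h w) so normalization is shared across the
            # frames of each sample; b=2 assumes the usual cond/uncond CFG pair.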
            def groupnorm32_mm_forward(self, x):
                x = rearrange(x, "(b f) c h w -> b c f h w", b=2)
                x = gn32_original_forward(self, x)
                x = rearrange(x, "b c f h w -> (b f) c h w", b=2)
                return x

            GroupNorm32.forward = groupnorm32_mm_forward

        logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet input blocks.")
        for mm_idx, unet_idx in enumerate([1, 2, 4, 5, 7, 8, 10, 11]):
            if inject_sdxl and mm_idx >= 6:
                break
            mm_idx0, mm_idx1 = mm_idx // 2, mm_idx % 2
            mm_inject = getattr(self.mm.down_blocks[mm_idx0], "temporal_attentions" if self.mm.is_hotshot else "motion_modules")[mm_idx1]
            unet.input_blocks[unet_idx].append(mm_inject)

        logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet output blocks.")
        for unet_idx in range(12):
            if inject_sdxl and unet_idx >= 9:
                break
            mm_idx0, mm_idx1 = unet_idx // 3, unet_idx % 3
            mm_inject = getattr(self.mm.up_blocks[mm_idx0], "temporal_attentions" if self.mm.is_hotshot else "motion_modules")[mm_idx1]
            if unet_idx % 3 == 2 and unet_idx != (8 if self.mm.is_xl else 11):
                unet.output_blocks[unet_idx].insert(-1, mm_inject)
            else:
                unet.output_blocks[unet_idx].append(mm_inject)

        self._set_ddim_alpha(sd_model)
        self._set_layer_mapping(sd_model)
        AnimateDiffMM.mm_injected = True
        logger.info(f"Injection finished.")


    def restore(self, sd_model):
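        """Remove injected motion modules from `sd_model` and undo the schedule and GroupNorm changes."""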
        if not AnimateDiffMM.mm_injected:
            logger.info("Motion module already removed.")
            return

        inject_sdxl = sd_model.is_sdxl or self.mm.is_xl
        sd_ver = "SDXL" if sd_model.is_sdxl else "SD1.5"
        self._restore_ddim_alpha(sd_model)
        unet = sd_model.model.diffusion_model

        logger.info(f"Removing motion module from {sd_ver} UNet input blocks.")
        for unet_idx in [1, 2, 4, 5, 7, 8, 10, 11]:
            if inject_sdxl and unet_idx >= 9:
                break
            unet.input_blocks[unet_idx].pop(-1)

        logger.info(f"Removing motion module from {sd_ver} UNet output blocks.")
        for unet_idx in range(12):
            if inject_sdxl and unet_idx >= 9:
                break
            if unet_idx % 3 == 2 and unet_idx != (8 if self.mm.is_xl else 11):
                unet.output_blocks[unet_idx].pop(-2)
            else:
                unet.output_blocks[unet_idx].pop(-1)

        if self.mm.is_v2:
            logger.info(f"Removing motion module from {sd_ver} UNet middle block.")
            unet.middle_block.pop(-2)
        elif self.mm.enable_gn_hack():
            logger.info(f"Restoring {sd_ver} GroupNorm32 forward function.")
            if self.mm.is_hotshot:
                from sgm.modules.diffusionmodules.util import GroupNorm32
            else:
                from ldm.modules.diffusionmodules.util import GroupNorm32
            GroupNorm32.forward = self.gn32_original_forward
            self.gn32_original_forward = None

        AnimateDiffMM.mm_injected = False
        logger.info(f"Removal finished.")
        if sd_model.lowvram:
            self.unload()


    def _set_ddim_alpha(self, sd_model):
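        """Swap the checkpoint's alphas_cumprod for the schedule the motion module was trained on.

        AnimateDiff modules assume a linear beta schedule (sqrt-space "scaled linear"
        for the AnimateDiff-SDXL module), which generally differs from the loaded
        checkpoint's schedule; the original values are backed up so
        `_restore_ddim_alpha` can put them back.
        """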
        logger.info(f"Setting DDIM alpha.")
        beta_start = 0.00085
        beta_end = 0.020 if self.mm.is_adxl else 0.012
        if self.mm.is_adxl:
            betas = torch.linspace(beta_start**0.5, beta_end**0.5, 1000, dtype=torch.float32, device=device) ** 2
        else:
            betas = torch.linspace(
                beta_start,
                beta_end,
                1000 if sd_model.is_sdxl else sd_model.num_timesteps,
                dtype=torch.float32,
                device=device,
            )
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.prev_alpha_cumprod = sd_model.alphas_cumprod
        sd_model.alphas_cumprod = alphas_cumprod
    

    def _set_layer_mapping(self, sd_model):
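        """Register motion module layers in network_layer_mapping so extra
        networks (e.g. motion LoRA) can locate and patch them by name."""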
        if hasattr(sd_model, 'network_layer_mapping'):
            for name, module in self.mm.named_modules():
                sd_model.network_layer_mapping[name] = module
                module.network_layer_name = name


    def _restore_ddim_alpha(self, sd_model):
        logger.info(f"Restoring DDIM alpha.")
        sd_model.alphas_cumprod = self.prev_alpha_cumprod
        self.prev_alpha_cumprod = None


    def unload(self):
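        """Move the motion module to CPU to free VRAM without discarding it."""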
        logger.info("Moving motion module to CPU")
        if self.mm is not None:
            self.mm.to(cpu)
        torch_gc()
        gc.collect()


    def remove(self):
        logger.info("Removing motion module from any memory")
        del self.mm
        self.mm = None
        torch_gc()
        gc.collect()


mm_animatediff = AnimateDiffMM()
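
# Minimal usage sketch (illustrative only; `extension_dir` is hypothetical and the
# motion module file must already exist under the configured model directory):
#
#     mm_animatediff.set_script_dir(extension_dir)
#     mm_animatediff.inject(shared.sd_model, "mm_sd_v15_v2.ckpt")
#     # ... run sampling; frames are stacked along the batch axis ...
#     mm_animatediff.restore(shared.sd_model)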