File size: 6,733 Bytes
be40477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from __future__ import annotations
from typing import Callable

class CallbacksMP:
    """String identifiers naming the callback types used as ``call_type`` keys.

    A callbacks dict is in the format:
        {"call_type": {"key": [Callable1, Callable2, ...]} }
    """
    ON_CLONE = "on_clone"
    # NOTE: the next two values carry an "_after" suffix that the constant names omit.
    ON_LOAD = "on_load_after"
    ON_DETACH = "on_detach_after"
    ON_CLEANUP = "on_cleanup"
    ON_PRE_RUN = "on_pre_run"
    ON_PREPARE_STATE = "on_prepare_state"
    ON_APPLY_HOOKS = "on_apply_hooks"
    ON_REGISTER_ALL_HOOK_PATCHES = "on_register_all_hook_patches"
    ON_INJECT_MODEL = "on_inject_model"
    ON_EJECT_MODEL = "on_eject_model"

    @classmethod
    def init_callbacks(cls) -> dict[str, dict[str, list[Callable]]]:
        """Return a new, empty callbacks dict (a distinct object on every call)."""
        return dict()

def add_callback(call_type: str, callback: Callable, transformer_options: dict, is_model_options=False):
    """Register ``callback`` for ``call_type`` under the ``None`` key.

    Delegates to ``add_callback_with_key`` with ``key=None``; the remaining
    arguments are forwarded unchanged.
    """
    add_callback_with_key(
        call_type,
        None,
        callback,
        transformer_options,
        is_model_options,
    )

def add_callback_with_key(call_type: str, key: str, callback: Callable, transformer_options: dict, is_model_options=False):
    """Append ``callback`` to the list stored at ``callbacks[call_type][key]``.

    Missing intermediate dicts/lists are created in place, so the dict passed
    in is mutated.

    Args:
        call_type: First-level key inside the "callbacks" dict.
        key: Second-level key grouping callbacks (``add_callback`` passes None).
        callback: The callable to register.
        transformer_options: Dict holding the "callbacks" entry, or - when
            ``is_model_options`` is true - a dict whose "transformer_options"
            entry holds it.
        is_model_options: Whether to descend into "transformer_options" first.
    """
    target = transformer_options
    if is_model_options:
        target = target.setdefault("transformer_options", {})
    registry = target.setdefault("callbacks", {})
    per_key = registry.setdefault(call_type, {})
    per_key.setdefault(key, []).append(callback)

def get_callbacks_with_key(call_type: str, key: str, transformer_options: dict, is_model_options=False):
    """Return a new list of the callbacks stored at ``callbacks[call_type][key]``.

    Nothing is created or modified in ``transformer_options``; any missing
    level yields an empty list. When ``is_model_options`` is true the lookup
    starts from the "transformer_options" entry of the dict passed in.
    """
    source = transformer_options.get("transformer_options", {}) if is_model_options else transformer_options
    per_key = source.get("callbacks", {}).get(call_type, {})
    # list(...) hands back a fresh list so callers cannot mutate the stored one.
    return list(per_key.get(key, []))

def get_all_callbacks(call_type: str, transformer_options: dict, is_model_options=False):
    """Return a new flat list of every callback registered for ``call_type``.

    The lists stored under each key are concatenated in the dict's iteration
    order. Nothing is created or modified in ``transformer_options``; missing
    levels yield an empty list. When ``is_model_options`` is true the lookup
    starts from the "transformer_options" entry of the dict passed in.
    """
    source = transformer_options.get("transformer_options", {}) if is_model_options else transformer_options
    per_key = source.get("callbacks", {}).get(call_type, {})
    return [callback for group in per_key.values() for callback in group]

class WrappersMP:
    """String identifiers naming the wrapper types used as ``wrapper_type`` keys.

    A wrappers dict is in the format:
        {"wrapper_type": {"key": [Callable1, Callable2, ...]} }
    """
    OUTER_SAMPLE = "outer_sample"
    SAMPLER_SAMPLE = "sampler_sample"
    CALC_COND_BATCH = "calc_cond_batch"
    APPLY_MODEL = "apply_model"
    DIFFUSION_MODEL = "diffusion_model"

    @classmethod
    def init_wrappers(cls) -> dict[str, dict[str, list[Callable]]]:
        """Return a new, empty wrappers dict (a distinct object on every call)."""
        return dict()

def add_wrapper(wrapper_type: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
    """Register ``wrapper`` for ``wrapper_type`` under the ``None`` key.

    Delegates to ``add_wrapper_with_key`` with ``key=None``; the remaining
    arguments are forwarded unchanged.
    """
    add_wrapper_with_key(
        wrapper_type,
        None,
        wrapper,
        transformer_options,
        is_model_options,
    )

def add_wrapper_with_key(wrapper_type: str, key: str, wrapper: Callable, transformer_options: dict, is_model_options=False):
    """Append ``wrapper`` to the list stored at ``wrappers[wrapper_type][key]``.

    Missing intermediate dicts/lists are created in place, so the dict passed
    in is mutated.

    Args:
        wrapper_type: First-level key inside the "wrappers" dict.
        key: Second-level key grouping wrappers (``add_wrapper`` passes None).
        wrapper: The callable to register.
        transformer_options: Dict holding the "wrappers" entry, or - when
            ``is_model_options`` is true - a dict whose "transformer_options"
            entry holds it.
        is_model_options: Whether to descend into "transformer_options" first.
    """
    target = transformer_options
    if is_model_options:
        target = target.setdefault("transformer_options", {})
    registry = target.setdefault("wrappers", {})
    bucket = registry.setdefault(wrapper_type, {})
    bucket.setdefault(key, []).append(wrapper)

def get_wrappers_with_key(wrapper_type: str, key: str, transformer_options: dict, is_model_options=False):
    """Return a new list of the wrappers stored at ``wrappers[wrapper_type][key]``.

    Nothing is created or modified in ``transformer_options``; any missing
    level yields an empty list. When ``is_model_options`` is true the lookup
    starts from the "transformer_options" entry of the dict passed in.
    """
    source = transformer_options.get("transformer_options", {}) if is_model_options else transformer_options
    bucket = source.get("wrappers", {}).get(wrapper_type, {})
    # list(...) hands back a fresh list so callers cannot mutate the stored one.
    return list(bucket.get(key, []))

def get_all_wrappers(wrapper_type: str, transformer_options: dict, is_model_options=False):
    """Return a new flat list of every wrapper registered for ``wrapper_type``.

    The lists stored under each key are concatenated in the dict's iteration
    order. Nothing is created or modified in ``transformer_options``; missing
    levels yield an empty list. When ``is_model_options`` is true the lookup
    starts from the "transformer_options" entry of the dict passed in.
    """
    source = transformer_options.get("transformer_options", {}) if is_model_options else transformer_options
    bucket = source.get("wrappers", {}).get(wrapper_type, {})
    return [wrapper for group in bucket.values() for wrapper in group]

class WrapperExecutor:
    """Runs a chain of wrappers around a function, one wrapper per executor step.

    Each wrapper is called as ``wrapper(executor, *args, **kwargs)``; calling
    the executor it received continues with the next wrapper, or with the
    original function once every wrapper has run.
    """
    def __init__(self, original: Callable, class_obj: object, wrappers: list[Callable], idx: int):
        # NOTE: class_obj is stored so that wrappers surrounding a class method can
        #       reach the class instance at runtime via executor.class_obj
        self.original = original
        self.class_obj = class_obj
        # Private copy: later changes to the caller's list do not affect this executor.
        self.wrappers = wrappers.copy()
        self.idx = idx
        # True once idx points past the final wrapper, i.e. only `original` is left.
        self.is_last = len(wrappers) == idx

    def __call__(self, *args, **kwargs):
        """Advance the chain: run the next wrapper, or the original function if none remain."""
        return self._create_next_executor().execute(*args, **kwargs)

    def execute(self, *args, **kwargs):
        """Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
        call_args, call_kwargs = list(args), dict(kwargs)
        if not self.is_last:
            # The wrapper gets this executor as its first argument so it can continue the chain.
            current_wrapper = self.wrappers[self.idx]
            return current_wrapper(self, *call_args, **call_kwargs)
        return self.original(*call_args, **call_kwargs)

    def _create_next_executor(self) -> 'WrapperExecutor':
        """Build the executor for position ``idx + 1``.

        Raises:
            Exception: if ``idx + 1`` is greater than the number of wrappers.
        """
        next_idx = self.idx + 1
        if next_idx > len(self.wrappers):
            raise Exception("Wrapper idx exceeded available wrappers; something went very wrong.")
        if self.class_obj is not None:
            return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, next_idx)
        return WrapperExecutor.new_executor(self.original, self.wrappers, next_idx)

    @classmethod
    def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
        """Create an executor with ``class_obj`` set to None."""
        return cls(original, class_obj=None, wrappers=wrappers, idx=idx)

    @classmethod
    def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
        """Create an executor that exposes ``class_obj`` to the wrappers."""
        return cls(original, class_obj, wrappers, idx=idx)

class PatcherInjection:
    """Pairs an ``inject`` callable with its matching ``eject`` callable."""
    def __init__(self, inject: Callable, eject: Callable):
        # Both callables are stored as given; nothing is invoked here.
        self.inject, self.eject = inject, eject

def copy_nested_dicts(input_dict: dict):
    """Return a copy of ``input_dict`` in which nested dicts and lists are duplicated.

    Nested dicts are copied recursively and lists are shallow-copied. All other
    values - including the items stored inside lists - are shared with the
    original by reference.
    """
    new_dict = input_dict.copy()
    for key, value in input_dict.items():
        if isinstance(value, dict):
            new_dict[key] = copy_nested_dicts(value)
        elif isinstance(value, list):
            new_dict[key] = value.copy()
    return new_dict

def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
    """Merge ``dict2`` into ``dict1`` and return the result.

    For every key of ``dict2``:
      * dict values are merged recursively with the value already present,
      * list values are appended after the items already present,
      * any other value replaces the one already present.

    The same rule applies at every nesting depth: ``dict1`` supplies the base
    and ``dict2`` is applied on top of it.

    Args:
        dict1: Base dictionary.
        dict2: Dictionary whose entries are applied on top; never modified.
        copy_dict1: When true (default) ``dict1`` is copied first and left
            untouched. When false ``dict1`` itself is updated and returned;
            its top-level lists are extended in place, while nested dicts are
            replaced by newly built merged dicts rather than mutated.

    Returns:
        The merged dictionary (``dict1`` itself when ``copy_dict1`` is false).
    """
    if copy_dict1:
        merged_dict = copy_nested_dicts(dict1)
    else:
        merged_dict = dict1
    for key, value in dict2.items():
        if isinstance(value, dict):
            curr_value = merged_dict.setdefault(key, {})
            # Keep dict1's side as the base at every depth. Passing the arguments
            # the other way round would flip precedence and list order on each
            # nesting level. The recursive call copies its base, so neither
            # curr_value nor value is mutated.
            merged_dict[key] = merge_nested_dicts(curr_value, value)
        elif isinstance(value, list):
            merged_dict.setdefault(key, []).extend(value)
        else:
            merged_dict[key] = value
    return merged_dict