File size: 10,802 Bytes
80c0bfe
 
 
 
194a41e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80c0bfe
b96c8c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15347cd
b96c8c5
 
a84a952
e4a20eb
b96c8c5
 
 
bcb57a1
b96c8c5
 
ec25ef6
071e791
 
b96c8c5
 
937a5c6
cec477c
 
cbbd444
 
 
 
0a6e1c9
b96c8c5
a447492
071e791
b96c8c5
 
 
 
194a41e
951c0bc
 
 
 
 
 
 
 
b96c8c5
951c0bc
 
 
 
 
 
 
c643307
 
 
 
cbbd444
c643307
cbbd444
 
c643307
 
b96c8c5
951c0bc
 
 
1ff2dfa
951c0bc
 
c6a81f6
951c0bc
f7b4ac7
e4a20eb
951c0bc
 
ec25ef6
071e791
 
951c0bc
 
a447492
951c0bc
 
 
cec477c
 
586b631
 
 
cbbd444
0a6e1c9
 
951c0bc
 
b96c8c5
 
 
 
 
 
 
 
 
 
80c0bfe
b96c8c5
951c0bc
b96c8c5
a5097b6
99b5399
 
b96c8c5
58f80cc
b96c8c5
58f80cc
b96c8c5
58f80cc
b96c8c5
 
bcb57a1
b96c8c5
bcb57a1
b96c8c5
bcb57a1
b96c8c5
bcb57a1
cec477c
bcb57a1
b96c8c5
 
bcb57a1
b96c8c5
bcb57a1
f7b4ac7
bcb57a1
58f80cc
cec477c
bcb57a1
b96c8c5
 
 
bcb57a1
b96c8c5
f7b4ac7
e4a20eb
f7b4ac7
cec477c
 
 
29cefb0
ec25ef6
 
071e791
 
15347cd
cec477c
e18e8e2
bcb57a1
c6a81f6
 
 
e4a20eb
c6a81f6
 
bcb57a1
c6a81f6
071e791
 
 
15347cd
cec477c
bcb57a1
b96c8c5
bcb57a1
b96c8c5
cec477c
 
 
b96c8c5
cec477c
c643307
bcb57a1
071e791
b96c8c5
bcb57a1
 
58f80cc
586b631
8e3f95a
cec477c
 
cbbd444
 
 
 
 
 
 
 
 
 
 
 
0a6e1c9
 
 
 
 
 
8e3f95a
bcb57a1
 
 
b96c8c5
bcb57a1
b96c8c5
 
 
0a6e1c9
 
c6a81f6
bcb57a1
c6a81f6
e18e8e2
bcb57a1
c6a81f6
 
a5752aa
 
bcb57a1
c6a81f6
bcb57a1
c6a81f6
bcb57a1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
import gradio as gr
import base64
import json
import torch
from gradio import Request
from gradio.context import Context


def persist(component):
    """Remember a component's value per user and restore it on page load.

    Two event handlers are wired up:
    * on load of the current root Blocks, the value last stored for
      ``request.username`` is returned (or the component's current value if
      nothing has been stored yet);
    * on every change of the component, the new value is stored under
      ``request.username``.

    NOTE(review): values are keyed only by ``request.username``; if the app
    runs without authentication that key is presumably shared by all
    visitors — confirm this is intended.

    Returns the component unchanged, so the call can wrap a constructor.
    """
    store = {}

    def update_session(value, request: Request):
        store[request.username] = value

    def resume_session(value, request: Request):
        return store.get(request.username, value)

    # Registration order matches the original: load handler first, then change.
    Context.root_block.load(resume_session, inputs=[component], outputs=component)
    component.change(update_session, inputs=[component])
    return component

def get_initial_config():
    """Build a fresh dict with the default UI / code-generation settings.

    Device selection: CUDA when available, otherwise Apple MPS when available,
    otherwise CPU. On CUDA the defaults are ``bfloat16`` with TF32 allowed;
    on every other device ``float16`` with TF32 disallowed.

    Boolean-like options are stored as the strings "True" / "False".
    A new dict (with new list/dict values) is created on every call.
    """
    if torch.cuda.is_available():
        device, data_type, allow_tensorfloat32 = "cuda", 'bfloat16', "True"
    else:
        # Only probe MPS when CUDA is absent, exactly like the elif chain did.
        device = "mps" if torch.backends.mps.is_available() else "cpu"
        data_type, allow_tensorfloat32 = 'float16', "False"

    return {
        # hardware / pipeline loading
        "device": device,
        "model": None,
        "cpu_offload": "False",
        "scheduler": None,
        "variant": None,
        "attention_slicing": "False",
        "pre_compile_unet": "False",
        "allow_tensorfloat32": allow_tensorfloat32,
        "use_safetensors": "False",
        "data_type": data_type,
        "refiner": "none",
        "safety_checker": "False",
        "requires_safety_checker": "False",
        "auto_encoder": None,
        "enable_vae_slicing": "True",
        "enable_vae_tiling": "True",
        # generation parameters
        "manual_seed": 42,
        "inference_steps": 10,
        "guidance_scale": 5,
        # adapters
        "adapter_textual_inversion": None,
        "adapter_textual_inversion_token": None,
        "adapter_lora": [],
        "adapter_lora_token": [],
        "adapter_lora_weight": [],
        "adapter_lora_balancing": {},
        "lora_scale": 0.5,
        # prompts
        "prompt": 'A white rabbit',
        "trigger_token": '',
        "negative_prompt": 'lowres, cropped, worst quality, low quality',
    }

# Order of the values returned by get_config_from_url.
# NOTE(review): presumably this must match the order of the Gradio output
# components the caller wires to that function — confirm before reordering.
_URL_CONFIG_KEYS = (
    'model',
    'device',
    'cpu_offload',
    'use_safetensors',
    'data_type',
    'refiner',
    'variant',
    'attention_slicing',
    'pre_compile_unet',
    'safety_checker',
    'requires_safety_checker',
    'auto_encoder',
    'enable_vae_slicing',
    'enable_vae_tiling',
    'scheduler',
    'prompt',
    'trigger_token',
    'negative_prompt',
    'inference_steps',
    'manual_seed',
    'guidance_scale',
    'adapter_textual_inversion',
    'adapter_textual_inversion_token',
    'adapter_lora',
    'adapter_lora_token',
    'adapter_lora_weight',
    'adapter_lora_balancing',
    'lora_scale',
)


def _decode_shared_config(encoded_params):
    """Decode the base64 ``config`` query value into a dict, or return None.

    None is returned when the parameter is absent or cannot be decoded into a
    dict, so the caller can fall back to the default configuration instead of
    failing the page load on a malformed (untrusted) URL.
    """
    if encoded_params is None:
        return None
    import ast

    try:
        text = base64.b64decode(encoded_params).decode('utf-8')
    except ValueError:  # binascii.Error and UnicodeDecodeError both derive from ValueError
        return None

    parsers = (
        # Payload that is already valid JSON: parse it verbatim so apostrophes
        # or the word "None" inside prompts are not rewritten.
        json.loads,
        # Python repr of a dict (single quotes, None/True/False). literal_eval
        # evaluates literals only and never executes code.
        ast.literal_eval,
        # Legacy quote/None rewriting kept as a last resort.
        lambda s: json.loads(s.replace("'", '"').replace('None', 'null')),
    )
    for parse in parsers:
        try:
            parsed = parse(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


def _apply_cookie_overrides(config, cookies):
    """Overwrite entries of ``config`` with same-named cookie values; return ``config``.

    'null' and empty cookies become None. For keys whose default is a list or
    dict the cookie holds JSON text: an empty/'null' cookie resets the entry
    to an empty container, and an unparsable or wrongly-typed cookie is
    ignored so the default is kept.
    """
    for key, default in list(config.items()):
        if key not in cookies:
            continue
        value = cookies[key]
        # transform empty values to a "Python-like" None
        if value == 'null' or value == '':
            value = None
        if isinstance(default, (list, dict)):
            if value is None:
                value = type(default)()
            else:
                try:
                    value = json.loads(value)
                except ValueError:
                    continue
                if not isinstance(value, type(default)):
                    continue
        config[key] = value
    return config


def get_config_from_url(initial_config, request: Request):
    """Resolve the starting configuration for a page load.

    If the URL has a ``config`` GET parameter (base64 of a JSON or
    Python-literal dict), its entries are laid over ``initial_config``.
    Otherwise — including when that parameter is malformed — cookies whose
    names match config keys override ``initial_config``.

    ``initial_config`` itself is never modified.

    Args:
        initial_config: default configuration dict (e.g. from get_initial_config).
        request: the Gradio request; its query parameters and cookies are read.

    Returns:
        A list of configuration values in the order of ``_URL_CONFIG_KEYS``.
    """
    # Copy so per-request overrides never leak into the caller's (possibly
    # shared, module-level) defaults and thus into other users' sessions.
    base_config = dict(initial_config)

    shared_config = _decode_shared_config(request.request.query_params.get('config'))
    if shared_config is not None:
        # Missing keys in a shared link fall back to the defaults.
        return_config = {**base_config, **shared_config}
    else:
        return_config = _apply_cookie_overrides(base_config, request.cookies)

    return [return_config[key] for key in _URL_CONFIG_KEYS]
        
def load_app_config(path='appConfig.json'):
    """Load the application configuration from a JSON file.

    Args:
        path: location of the JSON file; defaults to ``appConfig.json``
            relative to the current working directory.

    Returns:
        The parsed JSON content, or an empty dict when the file is missing,
        unreadable, or not valid JSON. In those cases a message is printed
        instead of raising. (Previously the failure path crashed with
        UnboundLocalError because the result variable was never assigned.)
    """
    app_config = {}
    try:
        # JSON is UTF-8 by spec; do not depend on the platform default encoding.
        with open(path, 'r', encoding='utf-8') as f:
            app_config = json.load(f)
    except FileNotFoundError:
        print("App config file not found.")
    except json.JSONDecodeError:
        print("Error decoding JSON in app config file.")
    except Exception as e:
        # Deliberate best-effort boundary: report and continue with defaults.
        print("An error occurred while loading app config:", str(e))

    return app_config

def set_config(config, key, value):
    """Store ``value`` under ``key`` in ``config`` (in place) and return ``config``.

    Any value whose string form is 'null' or 'none' (case-insensitive, so the
    Python ``None`` too) is stored as an empty string instead.
    """
    is_empty_marker = str(value).lower() in ('null', 'none')
    config[key] = '' if is_empty_marker else value
    return config

def _is_unset(value):
    """Return True when a config value means "not set": None, '', 'none' or 'null'."""
    return value is None or str(value).strip().lower() in ('', 'none', 'null')


def assemble_code(str_config):
    """Render a configuration dict as a diffusers Python snippet.

    Args:
        str_config: configuration mapping with the keys read below (device,
            model, data_type, adapter lists, prompts, ...). Boolean-like
            options may be bools or "True"/"False" strings; unset values may
            be None, '', 'none' or 'null'.

    Returns:
        The generated source code as one string joined with CRLF line endings.
    """
    config = str_config

    def enabled(key):
        # Anything whose string form is not 'false' counts as switched on.
        return str(config[key]).lower() != 'false'

    code = []

    code.append(f'''device = "{config['device']}"''')
    if config['data_type'] == "bfloat16":
        code.append('data_type = torch.bfloat16')
    else:
        code.append('data_type = torch.float16')

    code.append(f'torch.backends.cuda.matmul.allow_tf32 = {config["allow_tensorfloat32"]}')

    # '' / 'null' must also map to None, otherwise the snippet asks for an empty variant.
    if _is_unset(config["variant"]):
        code.append('variant = None')
    else:
        code.append(f'variant = "{config["variant"]}"')

    code.append(f'''use_safetensors = {config["use_safetensors"]}''')

    # INIT PIPELINE
    code.append(f'''pipeline = DiffusionPipeline.from_pretrained(
            "{config['model']}",
            use_safetensors=use_safetensors,
            torch_dtype=data_type,
            variant=variant).to(device)''')

    if enabled("attention_slicing"):
        code.append("pipeline.enable_attention_slicing()")
    if enabled("pre_compile_unet"):
        code.append('pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)')
    if enabled("cpu_offload"):
        code.append("pipeline.enable_model_cpu_offload()")

    # AUTO ENCODER
    if not _is_unset(config["auto_encoder"]):
        code.append(f'pipeline.vae = AutoencoderKL.from_pretrained("{config["auto_encoder"]}", torch_dtype=data_type).to(device)')

    if enabled("enable_vae_slicing"):
        code.append("pipeline.enable_vae_slicing()")
    if enabled("enable_vae_tiling"):
        code.append("pipeline.enable_vae_tiling()")

    # INIT REFINER (shares the text encoder and VAE of the generated `pipeline`)
    use_refiner = not _is_unset(config['refiner'])
    if use_refiner:
        code.append(f'''refiner = DiffusionPipeline.from_pretrained(
                "{config['refiner']}",
                text_encoder_2 = pipeline.text_encoder_2,
                vae = pipeline.vae,
                torch_dtype = data_type,
                use_safetensors = use_safetensors,
                variant=variant,
            ).to(device)''')

        if enabled("cpu_offload"):
            code.append("refiner.enable_model_cpu_offload()")
        if enabled("enable_vae_slicing"):
            code.append("refiner.enable_vae_slicing()")
        if enabled("enable_vae_tiling"):
            code.append("refiner.enable_vae_tiling()")

    # SAFETY CHECKER
    code.append(f'pipeline.requires_safety_checker = {config["requires_safety_checker"]}')
    if str(config["safety_checker"]).lower() == 'false':
        code.append('pipeline.safety_checker = None')

    # SCHEDULER/SOLVER
    if not _is_unset(config["scheduler"]):
        code.append(f'pipeline.scheduler = {config["scheduler"]}.from_config(pipeline.scheduler.config)')

    # MANUAL SEED/GENERATOR (empty or negative seed means "no fixed generator")
    seed = config['manual_seed']
    if seed is None or seed == '' or int(seed) < 0:
        code.append(f'# manual_seed = {seed}')
        code.append('generator = None')
    else:
        code.append(f'manual_seed = {seed}')
        code.append('generator = torch.manual_seed(manual_seed)')

    # ADAPTER: textual inversion
    inversion = config["adapter_textual_inversion"]
    inversion_token = config["adapter_textual_inversion_token"]
    if not _is_unset(inversion):
        if _is_unset(inversion_token):
            code.append(f'pipeline.load_textual_inversion("{inversion}")')
        else:
            code.append(f'pipeline.load_textual_inversion("{inversion}", token="{inversion_token}")')

    # ADAPTER: LoRA
    lora_models = config["adapter_lora"] or []
    lora_tokens = config["adapter_lora_token"] or []
    lora_weights = config["adapter_lora_weight"] or []
    if len(lora_models) > 0 and len(lora_models) == len(lora_weights):
        adapter_lora_balancing = []
        for index, adapter_lora in enumerate(lora_models):
            weight_name = lora_weights[index]
            token = lora_tokens[index]
            if _is_unset(weight_name):
                code.append(f'pipeline.load_lora_weights("{adapter_lora}", adapter_name="{token}")')
            else:
                code.append(f'pipeline.load_lora_weights("{adapter_lora}", weight_name="{weight_name}", adapter_name="{token}")')
            # No balance chosen for this adapter: use 1.0, the weight diffusers
            # applies when set_adapters gets no explicit weights.
            adapter_lora_balancing.append(config["adapter_lora_balancing"].get(adapter_lora, 1.0))

        code.append(f'adapter_weights = {adapter_lora_balancing}')
        code.append(f'pipeline.set_adapters({lora_tokens}, adapter_weights=adapter_weights)')

        # str() formatting: lora_scale is a float by default, so '+' concatenation failed.
        cross_attention_kwargs = f'{{"scale": {config["lora_scale"]}}}'
    else:
        cross_attention_kwargs = 'None'

    # Only non-empty prompt parts are joined, so unset tokens never appear as "None".
    prompt_parts = [config["prompt"], config["trigger_token"], inversion_token, ", ".join(lora_tokens)]
    full_prompt = " ".join(str(part) for part in prompt_parts if part is not None and str(part).strip() != '')
    negative_prompt = config["negative_prompt"]
    negative_prompt = '' if negative_prompt is None else str(negative_prompt)

    # json.dumps yields a double-quoted, escaped literal that is also valid Python,
    # so quotes or backslashes in user prompts cannot break the generated code.
    code.append(f'prompt = {json.dumps(full_prompt, ensure_ascii=False)}')
    code.append(f'negative_prompt = {json.dumps(negative_prompt, ensure_ascii=False)}')
    code.append(f'inference_steps = {config["inference_steps"]}')
    code.append(f'guidance_scale = {config["guidance_scale"]}')

    code.append(f'''image = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        generator=generator,
        num_inference_steps=inference_steps,
        cross_attention_kwargs={cross_attention_kwargs},
        guidance_scale=guidance_scale).images''')

    if use_refiner:
        # Keep the list from `.images` so the final `image[0]` works with and without a refiner.
        code.append('''image = refiner(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=inference_steps,
            image=image
            ).images''')

    code.append('image[0]')

    return '\r\n'.join(code)