openfree committed on
Commit
3eb69dd
1 Parent(s): ff18ed4

Delete app-backup.py

Files changed (1)
app-backup.py +0 -831
app-backup.py DELETED
@@ -1,831 +0,0 @@
- import os
- import gradio as gr
- import json
- import logging
- import torch
- from PIL import Image
- import spaces
- from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image, FluxControlNetModel
- from diffusers.pipelines import FluxControlNetPipeline
- from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
- from diffusers.utils import load_image
- from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
- import copy
- import random
- import time
- import requests
- import pandas as pd
- from transformers import pipeline
- from gradio_imageslider import ImageSlider
- import numpy as np
- import warnings
-
-
- huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
-
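- # Korean-to-English translator, kept on CPU so the GPU stays free for image generation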
- translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cpu")
-
- # Load prompts for randomization
- df = pd.read_csv('prompts.csv', header=None)
- prompt_values = df.values.flatten()
-
- # Load LoRAs from JSON file
- with open('loras.json', 'r') as f:
-     loras = json.load(f)
-
- # Initialize the base model
- dtype = torch.bfloat16
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- # Load the shared FLUX base model
- base_model = "black-forest-labs/FLUX.1-dev"
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
-
- # Setup for LoRA generation: tiny VAE for fast previews, full VAE for the final decode
- taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
- good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
-
- # Set up the image-to-image pipeline (shares the base model's components)
- pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
-     base_model,
-     vae=good_vae,
-     transformer=pipe.transformer,
-     text_encoder=pipe.text_encoder,
-     tokenizer=pipe.tokenizer,
-     text_encoder_2=pipe.text_encoder_2,
-     tokenizer_2=pipe.tokenizer_2,
-     torch_dtype=dtype
- ).to(device)
-
- # ControlNet setup for upscaling
- controlnet = FluxControlNetModel.from_pretrained(
-     "jasperai/Flux.1-dev-Controlnet-Upscaler", torch_dtype=torch.bfloat16
- ).to(device)
-
- # Upscale pipeline setup (reuses components of the existing pipe)
- pipe_upscale = FluxControlNetPipeline(
-     vae=pipe.vae,
-     text_encoder=pipe.text_encoder,
-     text_encoder_2=pipe.text_encoder_2,
-     tokenizer=pipe.tokenizer,
-     tokenizer_2=pipe.tokenizer_2,
-     transformer=pipe.transformer,
-     scheduler=pipe.scheduler,
-     controlnet=controlnet
- ).to(device)
-
- MAX_SEED = 2**32 - 1
- MAX_PIXEL_BUDGET = 1024 * 1024
-
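- # Bind the live-preview helper to the pipeline so generation can yield intermediate images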
- pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
-
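- # Simple context manager that prints how long the wrapped block took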
- class calculateDuration:
-     def __init__(self, activity_name=""):
-         self.activity_name = activity_name
-
-     def __enter__(self):
-         self.start_time = time.time()
-         return self
-
-     def __exit__(self, exc_type, exc_value, traceback):
-         self.end_time = time.time()
-         self.elapsed_time = self.end_time - self.start_time
-         if self.activity_name:
-             print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
-         else:
-             print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
-
- def download_file(url, directory=None):
-     if directory is None:
-         directory = os.getcwd()  # Use current working directory if not specified
-
-     # Get the filename from the URL
-     filename = url.split('/')[-1]
-
-     # Full path for the downloaded file
-     filepath = os.path.join(directory, filename)
-
-     # Download the file
-     response = requests.get(url)
-     response.raise_for_status()  # Raise an exception for bad status codes
-
-     # Write the content to the file
-     with open(filepath, 'wb') as file:
-         file.write(response.content)
-
-     return filepath
-
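- # Gallery click handler: toggles the clicked LoRA in or out of the two selection slots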
- def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
-     selected_index = evt.index
-     selected_indices = selected_indices or []
-     if selected_index in selected_indices:
-         selected_indices.remove(selected_index)
-     else:
-         if len(selected_indices) < 2:
-             selected_indices.append(selected_index)
-         else:
-             gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.")
-             return gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), width, height, gr.update(), gr.update()
-
-     selected_info_1 = "Select a LoRA 1"
-     selected_info_2 = "Select a LoRA 2"
-     lora_scale_1 = 1.15
-     lora_scale_2 = 1.15
-     lora_image_1 = None
-     lora_image_2 = None
-     if len(selected_indices) >= 1:
-         lora1 = loras_state[selected_indices[0]]
-         selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
-         lora_image_1 = lora1['image']
-     if len(selected_indices) >= 2:
-         lora2 = loras_state[selected_indices[1]]
-         selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
-         lora_image_2 = lora2['image']
-
-     if selected_indices:
-         last_selected_lora = loras_state[selected_indices[-1]]
-         new_placeholder = f"Type a prompt for {last_selected_lora['title']}"
-     else:
-         new_placeholder = "Type a prompt after selecting a LoRA"
-
-     return gr.update(placeholder=new_placeholder), selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, width, height, lora_image_1, lora_image_2
-
- def remove_lora_1(selected_indices, loras_state):
-     if len(selected_indices) >= 1:
-         selected_indices.pop(0)
-     selected_info_1 = "Select a LoRA 1"
-     selected_info_2 = "Select a LoRA 2"
-     lora_scale_1 = 1.15
-     lora_scale_2 = 1.15
-     lora_image_1 = None
-     lora_image_2 = None
-     if len(selected_indices) >= 1:
-         lora1 = loras_state[selected_indices[0]]
-         selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
-         lora_image_1 = lora1['image']
-     if len(selected_indices) >= 2:
-         lora2 = loras_state[selected_indices[1]]
-         selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
-         lora_image_2 = lora2['image']
-     return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
-
- def remove_lora_2(selected_indices, loras_state):
-     if len(selected_indices) >= 2:
-         selected_indices.pop(1)
-     selected_info_1 = "Select a LoRA 1"
-     selected_info_2 = "Select a LoRA 2"
-     lora_scale_1 = 1.15
-     lora_scale_2 = 1.15
-     lora_image_1 = None
-     lora_image_2 = None
-     if len(selected_indices) >= 1:
-         lora1 = loras_state[selected_indices[0]]
-         selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
-         lora_image_1 = lora1['image']
-     if len(selected_indices) >= 2:
-         lora2 = loras_state[selected_indices[1]]
-         selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
-         lora_image_2 = lora2['image']
-     return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
-
- def randomize_loras(selected_indices, loras_state):
-     try:
-         if len(loras_state) < 2:
-             raise gr.Error("Not enough LoRAs to randomize.")
-         selected_indices = random.sample(range(len(loras_state)), 2)
-         lora1 = loras_state[selected_indices[0]]
-         lora2 = loras_state[selected_indices[1]]
-         selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
-         selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
-         lora_scale_1 = 1.15
-         lora_scale_2 = 1.15
-         lora_image_1 = lora1['image']
-         lora_image_2 = lora2['image']
-         random_prompt = random.choice(prompt_values)
-         return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2, random_prompt
-     except Exception as e:
-         print(f"Error in randomize_loras: {str(e)}")
-         return "Error", "Error", [], 1.15, 1.15, None, None, ""
-
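- # Validate a user-supplied LoRA (Hugging Face path or direct .safetensors URL) and add it to the gallery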
- def add_custom_lora(custom_lora, selected_indices, current_loras):
-     if custom_lora:
-         try:
-             title, repo, path, trigger_word, image = check_custom_model(custom_lora)
-             print(f"Loaded custom LoRA: {repo}")
-             existing_item_index = next((index for (index, item) in enumerate(current_loras) if item['repo'] == repo), None)
-             if existing_item_index is None:
-                 if repo.endswith(".safetensors") and repo.startswith("http"):
-                     repo = download_file(repo)
-                 new_item = {
-                     "image": image if image else "/home/user/app/custom.png",
-                     "title": title,
-                     "repo": repo,
-                     "weights": path,
-                     "trigger_word": trigger_word
-                 }
-                 print(f"New LoRA: {new_item}")
-                 existing_item_index = len(current_loras)
-                 current_loras.append(new_item)
-
-             # Update gallery
-             gallery_items = [(item["image"], item["title"]) for item in current_loras]
-             # Update selected_indices if there's room
-             if len(selected_indices) < 2:
-                 selected_indices.append(existing_item_index)
-             else:
-                 gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.")
-
-             # Update selected_info and images
-             selected_info_1 = "Select a LoRA 1"
-             selected_info_2 = "Select a LoRA 2"
-             lora_scale_1 = 1.15
-             lora_scale_2 = 1.15
-             lora_image_1 = None
-             lora_image_2 = None
-             if len(selected_indices) >= 1:
-                 lora1 = current_loras[selected_indices[0]]
-                 selected_info_1 = f"### LoRA 1 Selected: {lora1['title']} ✨"
-                 lora_image_1 = lora1['image'] if lora1['image'] else None
-             if len(selected_indices) >= 2:
-                 lora2 = current_loras[selected_indices[1]]
-                 selected_info_2 = f"### LoRA 2 Selected: {lora2['title']} ✨"
-                 lora_image_2 = lora2['image'] if lora2['image'] else None
-             print("Finished adding custom LoRA")
-             return (
-                 current_loras,
-                 gr.update(value=gallery_items),
-                 selected_info_1,
-                 selected_info_2,
-                 selected_indices,
-                 lora_scale_1,
-                 lora_scale_2,
-                 lora_image_1,
-                 lora_image_2
-             )
-         except Exception as e:
-             print(e)
-             gr.Warning(str(e))
-             return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update()
-     else:
-         return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update()
-
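- # Remove the most recently added custom LoRA and refresh the gallery and selection state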
- def remove_custom_lora(selected_indices, current_loras):
-     if current_loras:
-         custom_lora_repo = current_loras[-1]['repo']
-         # Remove from loras list
-         current_loras = current_loras[:-1]
-         # Remove from selected_indices if selected
-         custom_lora_index = len(current_loras)
-         if custom_lora_index in selected_indices:
-             selected_indices.remove(custom_lora_index)
-         # Update gallery
-         gallery_items = [(item["image"], item["title"]) for item in current_loras]
-         # Update selected_info and images
-         selected_info_1 = "Select a LoRA 1"
-         selected_info_2 = "Select a LoRA 2"
-         lora_scale_1 = 1.15
-         lora_scale_2 = 1.15
-         lora_image_1 = None
-         lora_image_2 = None
-         if len(selected_indices) >= 1:
-             lora1 = current_loras[selected_indices[0]]
-             selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
-             lora_image_1 = lora1['image']
-         if len(selected_indices) >= 2:
-             lora2 = current_loras[selected_indices[1]]
-             selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
-             lora_image_2 = lora2['image']
-         return (
-             current_loras,
-             gr.update(value=gallery_items),
-             selected_info_1,
-             selected_info_2,
-             selected_indices,
-             lora_scale_1,
-             lora_scale_2,
-             lora_image_1,
-             lora_image_2
-         )
-     # Nothing to remove: leave every output unchanged
-     return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update()
-
- @spaces.GPU(duration=75)
- def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
-     print("Generating image...")
-     pipe.to("cuda")
-     generator = torch.Generator(device="cuda").manual_seed(seed)
-     with calculateDuration("Generating image"):
-         # Generate image
-         for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
-             prompt=prompt_mash,
-             num_inference_steps=steps,
-             guidance_scale=cfg_scale,
-             width=width,
-             height=height,
-             generator=generator,
-             joint_attention_kwargs={"scale": 1.0},
-             output_type="pil",
-             good_vae=good_vae,
-         ):
-             yield img
-
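- # Single-shot image-to-image generation (no streaming preview)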
- @spaces.GPU(duration=75)
- def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, seed):
-     pipe_i2i.to("cuda")
-     generator = torch.Generator(device="cuda").manual_seed(seed)
-     image_input = load_image(image_input_path)
-     final_image = pipe_i2i(
-         prompt=prompt_mash,
-         image=image_input,
-         strength=image_strength,
-         num_inference_steps=steps,
-         guidance_scale=cfg_scale,
-         width=width,
-         height=height,
-         generator=generator,
-         joint_attention_kwargs={"scale": 1.0},
-         output_type="pil",
-     ).images[0]
-     return final_image
-
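- # Main entry point: translate Korean prompts, apply trigger words, load the selected LoRA adapters, then generate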
- def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
-     try:
-         # Detect Korean characters and translate the prompt to English
-         if any('\u3131' <= char <= '\u318E' or '\uAC00' <= char <= '\uD7A3' for char in prompt):
-             translated = translator(prompt, max_length=512)[0]['translation_text']
-             print(f"Original prompt: {prompt}")
-             print(f"Translated prompt: {translated}")
-             prompt = translated
-
-         if not selected_indices:
-             raise gr.Error("You must select at least one LoRA before proceeding.")
-
-         selected_loras = [loras_state[idx] for idx in selected_indices]
-
-         # Build the prompt with trigger words
-         prepends = []
-         appends = []
-         for lora in selected_loras:
-             trigger_word = lora.get('trigger_word', '')
-             if trigger_word:
-                 if lora.get("trigger_position") == "prepend":
-                     prepends.append(trigger_word)
-                 else:
-                     appends.append(trigger_word)
-         prompt_mash = " ".join(prepends + [prompt] + appends)
-         print("Prompt Mash: ", prompt_mash)
-
-         # Unload previous LoRA weights
-         with calculateDuration("Unloading LoRA"):
-             pipe.unload_lora_weights()
-             pipe_i2i.unload_lora_weights()
-
-         print(pipe.get_active_adapters())
-         # Load LoRA weights with respective scales
-         lora_names = []
-         lora_weights = []
-         with calculateDuration("Loading LoRA weights"):
-             for idx, lora in enumerate(selected_loras):
-                 lora_name = f"lora_{idx}"
-                 lora_names.append(lora_name)
-                 lora_weights.append(lora_scale_1 if idx == 0 else lora_scale_2)
-                 lora_path = lora['repo']
-                 weight_name = lora.get("weights")
-                 print(f"Lora Path: {lora_path}")
-                 if image_input is not None:
-                     if weight_name:
-                         pipe_i2i.load_lora_weights(lora_path, weight_name=weight_name, low_cpu_mem_usage=True, adapter_name=lora_name)
-                     else:
-                         pipe_i2i.load_lora_weights(lora_path, low_cpu_mem_usage=True, adapter_name=lora_name)
-                 else:
-                     if weight_name:
-                         pipe.load_lora_weights(lora_path, weight_name=weight_name, low_cpu_mem_usage=True, adapter_name=lora_name)
-                     else:
-                         pipe.load_lora_weights(lora_path, low_cpu_mem_usage=True, adapter_name=lora_name)
-             print("Loaded LoRAs:", lora_names)
-             print("Adapter weights:", lora_weights)
-             if image_input is not None:
-                 pipe_i2i.set_adapters(lora_names, adapter_weights=lora_weights)
-             else:
-                 pipe.set_adapters(lora_names, adapter_weights=lora_weights)
-         print(pipe.get_active_adapters())
-         # Set random seed for reproducibility
-         with calculateDuration("Randomizing seed"):
-             if randomize_seed:
-                 seed = random.randint(0, MAX_SEED)
-
-         # Generate image
-         if image_input is not None:
-             final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, seed)
-         else:
-             image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)
-             final_image = None
-             step_counter = 0
-             for image in image_generator:
-                 step_counter += 1
-                 final_image = image
-                 progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
-                 yield image, seed, gr.update(value=progress_bar, visible=True)
-
-         if final_image is None:
-             raise Exception("Failed to generate image")
-
-         # This function is a generator, so the final result must be yielded rather than returned
-         yield final_image, seed, gr.update(visible=False)
-     except Exception as e:
-         print(f"Error in run_lora: {str(e)}")
-         yield None, seed, gr.update(visible=False)
-
-
- run_lora.zerogpu = True
-
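- # Resolve a "user/repo" Hugging Face link to (title, repo, weight filename, trigger word, preview image URL)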
- def get_huggingface_safetensors(link):
-     split_link = link.split("/")
-     if len(split_link) == 2:
-         model_card = ModelCard.load(link)
-         base_model = model_card.data.get("base_model")
-         print(f"Base model: {base_model}")
-         if base_model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
-             raise Exception("Not a FLUX LoRA!")
-         image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
-         trigger_word = model_card.data.get("instance_prompt", "")
-         image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
-         fs = HfFileSystem()
-         safetensors_name = None
-         try:
-             list_of_files = fs.ls(link, detail=False)
-             for file in list_of_files:
-                 if file.endswith(".safetensors"):
-                     safetensors_name = file.split("/")[-1]
-                 if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
-                     image_elements = file.split("/")
-                     image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
-         except Exception as e:
-             print(e)
-             raise gr.Error("Invalid Hugging Face repository with a *.safetensors LoRA")
-         if not safetensors_name:
-             raise gr.Error("No *.safetensors file found in the repository")
-         return split_link[1], link, safetensors_name, trigger_word, image_url
-     else:
-         raise gr.Error("Invalid Hugging Face repository link")
-
- def check_custom_model(link):
-     if link.endswith(".safetensors"):
-         # Treat as direct link to the LoRA weights
-         title = os.path.basename(link)
-         repo = link
-         path = None  # No specific weight name
-         trigger_word = ""
-         image_url = None
-         return title, repo, path, trigger_word, image_url
-     elif link.startswith("https://"):
-         if "huggingface.co" in link:
-             link_split = link.split("huggingface.co/")
-             return get_huggingface_safetensors(link_split[1])
-         else:
-             raise Exception("Unsupported URL")
-     else:
-         # Assume it's a Hugging Face model path
-         return get_huggingface_safetensors(link)
-
- def update_history(new_image, history):
-     """Updates the history gallery with the new image."""
-     if history is None:
-         history = []
-     if new_image is not None:
-         history.insert(0, new_image)
-     return history
-
- css = '''
- #gen_btn{height: 100%}
- #title{text-align: center}
- #title h1{font-size: 3em; display:inline-flex; align-items:center}
- #title img{width: 100px; margin-right: 0.25em}
- #gallery .grid-wrap{height: 5vh}
- #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
- .custom_lora_card{margin-bottom: 1em}
- .card_internal{display: flex;height: 100px;margin-top: .5em}
- .card_internal img{margin-right: 1em}
- .styler{--form-gap-width: 0px !important}
- #progress{height:30px}
- #progress .generating{display:none}
- .progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
- .progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
- #component-8, .button_total{height: 100%; align-self: stretch;}
- #loaded_loras [data-testid="block-info"]{font-size:80%}
- #custom_lora_structure{background: var(--block-background-fill)}
- #custom_lora_btn{margin-top: auto;margin-bottom: 11px}
- #random_btn{font-size: 300%}
- #component-11{align-self: stretch;}
- footer {visibility: hidden;}
- '''
-
- # Helper functions for upscaling
- def process_input(input_image, upscale_factor, **kwargs):
-     w, h = input_image.size
-     w_original, h_original = w, h
-     aspect_ratio = w / h
-
-     was_resized = False
-
-     max_size = int(np.sqrt(MAX_PIXEL_BUDGET / (upscale_factor ** 2)))
-     if w > max_size or h > max_size:
-         if w > h:
-             w_new = max_size
-             h_new = int(w_new / aspect_ratio)
-         else:
-             h_new = max_size
-             w_new = int(h_new * aspect_ratio)
-
-         input_image = input_image.resize((w_new, h_new), Image.LANCZOS)
-         was_resized = True
-         gr.Info(f"Input image resized to {w_new}x{h_new} to fit within pixel budget after upscaling.")
-
-     # resize to multiple of 8
-     w, h = input_image.size
-     w = w - w % 8
-     h = h - h % 8
-
-     return input_image.resize((w, h)), w_original, h_original, was_resized
-
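- # ZeroGPU entry point: upscale the input image with the Flux ControlNet upscaler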
- @spaces.GPU
- def infer_upscale(
-     seed,
-     randomize_seed,
-     input_image,
-     num_inference_steps,
-     upscale_factor,
-     controlnet_conditioning_scale,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     if input_image is None:
-         return None, seed, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(visible=True, value="Please upload an image for upscaling.")
-
-     try:
-         if randomize_seed:
-             seed = random.randint(0, MAX_SEED)
-
-         input_image, w_original, h_original, was_resized = process_input(input_image, upscale_factor)
-
-         # rescale with upscale factor
-         w, h = input_image.size
-         control_image = input_image.resize((w * upscale_factor, h * upscale_factor), Image.LANCZOS)
-
-         generator = torch.Generator(device=device).manual_seed(seed)
-
-         gr.Info("Upscaling image...")
-         # Move the whole pipeline to the same device
-         pipe_upscale.to(device)
-
-         # Ensure the image is in RGB format
-         if control_image.mode != 'RGB':
-             control_image = control_image.convert('RGB')
-
-         # Convert to tensor and add batch dimension
-         control_image = torch.from_numpy(np.array(control_image)).permute(2, 0, 1).float().unsqueeze(0).to(device) / 255.0
-
-         with torch.no_grad():
-             image = pipe_upscale(
-                 prompt="",
-                 control_image=control_image,
-                 controlnet_conditioning_scale=controlnet_conditioning_scale,
-                 num_inference_steps=num_inference_steps,
-                 guidance_scale=3.5,
-                 generator=generator,
-             ).images[0]
-
-         # Convert the image back to a PIL Image if the pipeline returned a tensor
-         if isinstance(image, torch.Tensor):
-             image = image.cpu().permute(1, 2, 0).numpy()
-             # Ensure the image data is in the correct range
-             image = np.clip(image * 255, 0, 255).astype(np.uint8)
-             image = Image.fromarray(image)
-
-         if was_resized:
-             gr.Info(
-                 f"Resizing output image to targeted {w_original * upscale_factor}x{h_original * upscale_factor} size."
-             )
-             image = image.resize((w_original * upscale_factor, h_original * upscale_factor), Image.LANCZOS)
-
-         return image, seed, num_inference_steps, upscale_factor, controlnet_conditioning_scale, gr.update(), gr.update(visible=False)
-     except Exception as e:
-         print(f"Error in infer_upscale: {str(e)}")
-         import traceback
-         traceback.print_exc()
-         return None, seed, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(visible=True, value=f"Error: {str(e)}")
-
- def check_upscale_input(input_image, *args):
-     if input_image is None:
-         return gr.update(interactive=False), *args, gr.update(visible=True, value="Please upload an image for upscaling.")
-     return gr.update(interactive=True), *args, gr.update(visible=False)
-
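- # Build the Gradio UI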
- with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css, delete_cache=(60, 3600)) as app:
-     loras_state = gr.State(loras)
-     selected_indices = gr.State([])
-
-     with gr.Row():
-         with gr.Column(scale=3):
-             prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
-         with gr.Column(scale=1):
-             generate_button = gr.Button("Generate", variant="primary", elem_classes=["button_total"])
-
-     with gr.Row(elem_id="loaded_loras"):
-         with gr.Column(scale=1, min_width=25):
-             randomize_button = gr.Button("🎲", variant="secondary", scale=1, elem_id="random_btn")
-         with gr.Column(scale=8):
-             with gr.Row():
-                 with gr.Column(scale=0, min_width=50):
-                     lora_image_1 = gr.Image(label="LoRA 1 Image", interactive=False, min_width=50, width=50, show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False, height=50)
-                 with gr.Column(scale=3, min_width=100):
-                     selected_info_1 = gr.Markdown("Select a LoRA 1")
-                 with gr.Column(scale=5, min_width=50):
-                     lora_scale_1 = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=3, step=0.01, value=1.15)
-             with gr.Row():
-                 remove_button_1 = gr.Button("Remove", size="sm")
-         with gr.Column(scale=8):
-             with gr.Row():
-                 with gr.Column(scale=0, min_width=50):
-                     lora_image_2 = gr.Image(label="LoRA 2 Image", interactive=False, min_width=50, width=50, show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False, height=50)
-                 with gr.Column(scale=3, min_width=100):
-                     selected_info_2 = gr.Markdown("Select a LoRA 2")
-                 with gr.Column(scale=5, min_width=50):
-                     lora_scale_2 = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=3, step=0.01, value=1.15)
-             with gr.Row():
-                 remove_button_2 = gr.Button("Remove", size="sm")
-
-     with gr.Row():
-         with gr.Column():
-             with gr.Group():
-                 with gr.Row(elem_id="custom_lora_structure"):
-                     custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path or *.safetensors public URL", placeholder="ginipick/flux-lora-eric-cat", scale=3, min_width=150)
-                     add_custom_lora_button = gr.Button("Add Custom LoRA", elem_id="custom_lora_btn", scale=2, min_width=150)
-                 remove_custom_lora_button = gr.Button("Remove Custom LoRA", visible=False)
-             gr.Markdown("[Check the list of FLUX LoRAs](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
-             gallery = gr.Gallery(
-                 [(item["image"], item["title"]) for item in loras],
-                 label="Or pick from the LoRA Explorer gallery",
-                 allow_preview=False,
-                 columns=4,
-                 elem_id="gallery"
-             )
-         with gr.Column():
-             progress_bar = gr.Markdown(elem_id="progress", visible=False)
-             result = gr.Image(label="Generated Image", interactive=False)
-             with gr.Accordion("History", open=False):
-                 history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
-
-     with gr.Row():
-         with gr.Accordion("Advanced Settings", open=False):
-             with gr.Row():
-                 input_image = gr.Image(label="Input image", type="filepath")
-                 image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
-             with gr.Column():
-                 with gr.Row():
-                     cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
-                     steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
-                 with gr.Row():
-                     width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
-                     height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
-                 with gr.Row():
-                     randomize_seed = gr.Checkbox(True, label="Randomize seed")
-                     seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
-
-     # Upscaling UI
-     with gr.Row():
-         upscale_button = gr.Button("Upscale", interactive=False)
-
-     with gr.Row():
-         with gr.Column(scale=4):
-             upscale_input = gr.Image(label="Input Image for Upscaling", type="pil")
-         with gr.Column(scale=1):
-             upscale_steps = gr.Slider(
-                 label="Number of Inference Steps for Upscaling",
-                 minimum=8,
-                 maximum=50,
-                 step=1,
-                 value=28,
-             )
-             upscale_factor = gr.Slider(
-                 label="Upscale Factor",
-                 minimum=1,
-                 maximum=4,
-                 step=1,
-                 value=4,
-             )
-             controlnet_conditioning_scale = gr.Slider(
-                 label="Controlnet Conditioning Scale",
-                 minimum=0.1,
-                 maximum=1.0,
-                 step=0.05,
-                 value=0.5,  # default lowered to 0.5
-             )
-             upscale_seed = gr.Slider(
-                 label="Seed for Upscaling",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=42,
-             )
-             upscale_randomize_seed = gr.Checkbox(label="Randomize seed for Upscaling", value=True)
-             upscale_error = gr.Markdown(visible=False, value="Please provide an input image for upscaling.")
-
-     with gr.Row():
-         upscale_result = gr.Image(label="Upscaled Image", type="pil")
-         upscale_seed_output = gr.Number(label="Seed Used", precision=0)
-
-     gallery.select(
-         update_selection,
-         inputs=[selected_indices, loras_state, width, height],
-         outputs=[prompt, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, width, height, lora_image_1, lora_image_2]
-     )
-     remove_button_1.click(
-         remove_lora_1,
-         inputs=[selected_indices, loras_state],
-         outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
-     )
-     remove_button_2.click(
-         remove_lora_2,
-         inputs=[selected_indices, loras_state],
-         outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
-     )
-     randomize_button.click(
-         randomize_loras,
-         inputs=[selected_indices, loras_state],
-         outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2, prompt]
-     )
-     add_custom_lora_button.click(
-         add_custom_lora,
-         inputs=[custom_lora, selected_indices, loras_state],
-         outputs=[loras_state, gallery, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
-     )
-     remove_custom_lora_button.click(
-         remove_custom_lora,
-         inputs=[selected_indices, loras_state],
-         outputs=[loras_state, gallery, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
-     )
-
-     gr.on(
-         triggers=[generate_button.click, prompt.submit],
-         fn=run_lora,
-         inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state],
-         outputs=[result, seed, progress_bar]
-     ).then(
-         fn=lambda x, history: update_history(x, history) if x is not None else history,
-         inputs=[result, history_gallery],
-         outputs=history_gallery,
-     )
-
-     upscale_input.upload(
-         lambda x: gr.update(interactive=x is not None),
-         inputs=[upscale_input],
-         outputs=[upscale_button]
-     )
-
-     upscale_button.click(
-         infer_upscale,
-         inputs=[
-             upscale_seed,
-             upscale_randomize_seed,
-             upscale_input,
-             upscale_steps,
-             upscale_factor,
-             controlnet_conditioning_scale,
-         ],
-         outputs=[
-             upscale_result,
-             upscale_seed_output,
-             upscale_steps,
-             upscale_factor,
-             controlnet_conditioning_scale,
-             upscale_randomize_seed,
-             upscale_error
-         ],
-     )
-
-
- if __name__ == "__main__":
-     app.queue(max_size=20)
-     app.launch(debug=True)