barreloflube committed on
Commit
0a8b4a2
•
1 Parent(s): 5a59c13

Refactor UI structure and import spaces module

Files changed (2)
  1. app.py +674 -6
  2. app2.py +6 -673
app.py CHANGED
@@ -1,11 +1,482 @@
1
- import gradio as gr
2
- import spaces
3
 
4
- from src.ui import (
5
- image_tab,
6
  )
7
 
8
 
9
  css = """
10
  @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;600&display=swap');
11
  body {
@@ -27,6 +498,11 @@ body {
27
  """
28
 
29
 
30
  # Main Gradio app
31
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
32
  # Header
@@ -40,14 +516,206 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
40
  # Tabs
41
  with gr.Tabs():
42
  with gr.Tab(label="🖼️ Image"):
43
- image_tab()
44
  with gr.Tab(label="🎵 Audio"):
45
  gr.Label("Coming soon!")
46
  with gr.Tab(label="🎬 Video"):
47
  gr.Label("Coming soon!")
48
  with gr.Tab(label="📄 Text"):
49
  gr.Label("Coming soon!")
50
-
 
51
  demo.launch(
52
  share=False,
53
  debug=True,
 
1
+ # Testing one file gradio app for zero gpu spaces not working as expected.
2
+ # Check here for the issue:
3
+ import gc
4
+ import json
5
+ import random
6
+ from typing import List, Optional
7
 
8
+ import spaces
9
+ import gradio as gr
10
+ from huggingface_hub import ModelCard
11
+ import torch
12
+ import numpy as np
13
+ from pydantic import BaseModel
14
+ from PIL import Image
15
+ from diffusers import (
16
+ FluxPipeline,
17
+ FluxImg2ImgPipeline,
18
+ FluxInpaintPipeline,
19
+ FluxControlNetPipeline,
20
+ StableDiffusionXLPipeline,
21
+ StableDiffusionXLImg2ImgPipeline,
22
+ StableDiffusionXLInpaintPipeline,
23
+ StableDiffusionXLControlNetPipeline,
24
+ StableDiffusionXLControlNetImg2ImgPipeline,
25
+ StableDiffusionXLControlNetInpaintPipeline,
26
+ AutoPipelineForText2Image,
27
+ AutoPipelineForImage2Image,
28
+ AutoPipelineForInpainting,
29
+ DiffusionPipeline,
30
+ AutoencoderKL,
31
+ FluxControlNetModel,
32
+ FluxMultiControlNetModel,
33
+ ControlNetModel,
34
+ )
35
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
36
+ from huggingface_hub import hf_hub_download
37
+ from transformers import CLIPFeatureExtractor
38
+ from photomaker import FaceAnalysis2
39
+ from diffusers.schedulers import *
40
+ from huggingface_hub import hf_hub_download
41
+ from safetensors.torch import load_file
42
+ from controlnet_aux.processor import Processor
43
+ from photomaker import (
44
+ PhotoMakerStableDiffusionXLPipeline,
45
+ PhotoMakerStableDiffusionXLControlNetPipeline,
46
+ analyze_faces
47
  )
48
+ from sd_embed.embedding_funcs import get_weighted_text_embeddings_sdxl, get_weighted_text_embeddings_flux1
49
+
50
+
51
+ # Initialize System
52
+ def load_sd():
53
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
+ device = "cuda" if torch.cuda.is_available() else "cpu"
55
+
56
+ # Models
57
+ models = [
58
+ {
59
+ "repo_id": "black-forest-labs/FLUX.1-dev",
60
+ "loader": "flux",
61
+ "compute_type": torch.bfloat16,
62
+ },
63
+ {
64
+ "repo_id": "SG161222/RealVisXL_V4.0",
65
+ "loader": "xl",
66
+ "compute_type": torch.float16,
67
+ }
68
+ ]
69
+
70
+ for model in models:
71
+ try:
72
+ model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
73
+ model['repo_id'],
74
+ torch_dtype = model['compute_type'],
75
+ safety_checker = None,
76
+ variant = "fp16"
77
+ ).to(device)
78
+ model["pipeline"].enable_model_cpu_offload()
79
+ except:
80
+ model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
81
+ model['repo_id'],
82
+ torch_dtype = model['compute_type'],
83
+ safety_checker = None
84
+ ).to(device)
85
+ model["pipeline"].enable_model_cpu_offload()
86
+
87
+
88
+ # VAE n Refiner
89
+ sdxl_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
90
+ refiner = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", vae=sdxl_vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
91
+ refiner.enable_model_cpu_offload()
92
+
93
+
94
+ # Safety Checker
95
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device)
96
+ feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", from_pt=True)
97
+
98
+
99
+ # Controlnets
100
+ controlnet_models = [
101
+ {
102
+ "repo_id": "xinsir/controlnet-depth-sdxl-1.0",
103
+ "name": "depth_xl",
104
+ "layers": ["depth"],
105
+ "loader": "xl",
106
+ "compute_type": torch.float16,
107
+ },
108
+ {
109
+ "repo_id": "xinsir/controlnet-canny-sdxl-1.0",
110
+ "name": "canny_xl",
111
+ "layers": ["canny"],
112
+ "loader": "xl",
113
+ "compute_type": torch.float16,
114
+ },
115
+ {
116
+ "repo_id": "xinsir/controlnet-openpose-sdxl-1.0",
117
+ "name": "openpose_xl",
118
+ "layers": ["pose"],
119
+ "loader": "xl",
120
+ "compute_type": torch.float16,
121
+ },
122
+ {
123
+ "repo_id": "xinsir/controlnet-scribble-sdxl-1.0",
124
+ "name": "scribble_xl",
125
+ "layers": ["scribble"],
126
+ "loader": "xl",
127
+ "compute_type": torch.float16,
128
+ },
129
+ {
130
+ "repo_id": "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
131
+ "name": "flux1_union_pro",
132
+ "layers": ["canny_fl", "tile_fl", "depth_fl", "blur_fl", "pose_fl", "gray_fl", "low_quality_fl"],
133
+ "loader": "flux-multi",
134
+ "compute_type": torch.bfloat16,
135
+ }
136
+ ]
137
+
138
+ for controlnet in controlnet_models:
139
+ if controlnet["loader"] == "xl":
140
+ controlnet["controlnet"] = ControlNetModel.from_pretrained(
141
+ controlnet["repo_id"],
142
+ torch_dtype = controlnet['compute_type']
143
+ ).to(device)
144
+ elif controlnet["loader"] == "flux-multi":
145
+ controlnet["controlnet"] = FluxMultiControlNetModel([FluxControlNetModel.from_pretrained(
146
+ controlnet["repo_id"],
147
+ torch_dtype = controlnet['compute_type']
148
+ ).to(device)])
149
+ #TODO: Add support for flux only controlnet
150
+
151
+
152
+ # Face Detection (for PhotoMaker)
153
+ face_detector = FaceAnalysis2(providers=['CUDAExecutionProvider'], allowed_modules=['detection', 'recognition'])
154
+ face_detector.prepare(ctx_id=0, det_size=(640, 640))
155
+
156
+
157
+ # PhotoMaker V2 (for SDXL only)
158
+ photomaker_ckpt = hf_hub_download(repo_id="TencentARC/PhotoMaker-V2", filename="photomaker-v2.bin", repo_type="model")
159
+
160
+ return device, models, sdxl_vae, refiner, safety_checker, feature_extractor, controlnet_models, face_detector, photomaker_ckpt
161
+
162
+
163
+ device, models, sdxl_vae, refiner, safety_checker, feature_extractor, controlnet_models, face_detector, photomaker_ckpt = load_sd()
164
+
165
+
166
+ # Models
167
+ class ControlNetReq(BaseModel):
168
+ controlnets: List[str] # ["canny", "tile", "depth"]
169
+ control_images: List[Image.Image]
170
+ controlnet_conditioning_scale: List[float]
171
+
172
+ class Config:
173
+ arbitrary_types_allowed=True
174
+
175
+
176
+ class SDReq(BaseModel):
177
+ model: str = ""
178
+ prompt: str = ""
179
+ negative_prompt: Optional[str] = "black-forest-labs/FLUX.1-dev"
180
+ fast_generation: Optional[bool] = True
181
+ loras: Optional[list] = []
182
+ embeddings: Optional[list] = []
183
+ resize_mode: Optional[str] = "resize_and_fill" # resize_only, crop_and_resize, resize_and_fill
184
+ scheduler: Optional[str] = "euler_fl"
185
+ height: int = 1024
186
+ width: int = 1024
187
+ num_images_per_prompt: int = 1
188
+ num_inference_steps: int = 8
189
+ guidance_scale: float = 3.5
190
+ seed: Optional[int] = 0
191
+ refiner: bool = False
192
+ vae: bool = True
193
+ controlnet_config: Optional[ControlNetReq] = None
194
+ photomaker_images: Optional[List[Image.Image]] = None
195
+
196
+ class Config:
197
+ arbitrary_types_allowed=True
198
+
199
+
200
+ class SDImg2ImgReq(SDReq):
201
+ image: Image.Image
202
+ strength: float = 1.0
203
+
204
+ class Config:
205
+ arbitrary_types_allowed=True
206
+
207
+
208
+ class SDInpaintReq(SDImg2ImgReq):
209
+ mask_image: Image.Image
210
+
211
+ class Config:
212
+ arbitrary_types_allowed=True
213
+
214
+
215
+ # Helper functions
216
+ def get_controlnet(controlnet_config: ControlNetReq):
217
+ control_mode = []
218
+ controlnet = []
219
+
220
+ for m in controlnet_models:
221
+ for c in controlnet_config.controlnets:
222
+ if c in m["layers"]:
223
+ control_mode.append(m["layers"].index(c))
224
+ controlnet.append(m["controlnet"])
225
+
226
+ return controlnet, control_mode
227
+
228
+
229
+ def get_pipe(request: SDReq | SDImg2ImgReq | SDInpaintReq):
230
+ for m in models:
231
+ if m["repo_id"] == request.model:
232
+ pipeline = m['pipeline']
233
+ controlnet, control_mode = get_controlnet(request.controlnet_config) if request.controlnet_config else (None, None)
234
+
235
+ pipe_args = {
236
+ "pipeline": pipeline,
237
+ "control_mode": control_mode,
238
+ }
239
+ if request.controlnet_config:
240
+ pipe_args["controlnet"] = controlnet
241
+
242
+ if not request.photomaker_images:
243
+ if isinstance(request, SDReq):
244
+ pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
245
+ elif isinstance(request, SDImg2ImgReq):
246
+ pipe_args['pipeline'] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
247
+ elif isinstance(request, SDInpaintReq):
248
+ pipe_args['pipeline'] = AutoPipelineForInpainting.from_pipe(**pipe_args)
249
+ else:
250
+ raise ValueError(f"Unknown request type: {type(request)}")
251
+ elif isinstance(request, any([PhotoMakerStableDiffusionXLPipeline, PhotoMakerStableDiffusionXLControlNetPipeline])):
252
+ if request.controlnet_config:
253
+ pipe_args['pipeline'] = PhotoMakerStableDiffusionXLControlNetPipeline.from_pipe(**pipe_args)
254
+ else:
255
+ pipe_args['pipeline'] = PhotoMakerStableDiffusionXLPipeline.from_pipe(**pipe_args)
256
+ else:
257
+ raise ValueError(f"Invalid request type: {type(request)}")
258
+
259
+ return pipe_args
260
 
261
 
262
+ def load_scheduler(pipeline, scheduler):
263
+ schedulers = {
264
+ "dpmpp_2m": (DPMSolverMultistepScheduler, {}),
265
+ "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
266
+ "dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++"}),
267
+ "dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True}),
268
+ "dpmpp_sde": (DPMSolverSinglestepScheduler, {}),
269
+ "dpmpp_sde_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
270
+ "dpm2": (KDPM2DiscreteScheduler, {}),
271
+ "dpm2_k": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
272
+ "dpm2_a": (KDPM2AncestralDiscreteScheduler, {}),
273
+ "dpm2_a_k": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
274
+ "euler": (EulerDiscreteScheduler, {}),
275
+ "euler_a": (EulerAncestralDiscreteScheduler, {}),
276
+ "heun": (HeunDiscreteScheduler, {}),
277
+ "lms": (LMSDiscreteScheduler, {}),
278
+ "lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
279
+ "deis": (DEISMultistepScheduler, {}),
280
+ "unipc": (UniPCMultistepScheduler, {}),
281
+ "fm_euler": (FlowMatchEulerDiscreteScheduler, {}),
282
+ }
283
+ scheduler_class, kwargs = schedulers.get(scheduler, (None, {}))
284
+
285
+ if scheduler_class is not None:
286
+ scheduler = scheduler_class.from_config(pipeline.scheduler.config, **kwargs)
287
+ else:
288
+ raise ValueError(f"Unknown scheduler: {scheduler}")
289
+
290
+ return scheduler
291
+
292
+
293
+ def load_loras(pipeline, loras, fast_generation):
294
+ for i, lora in enumerate(loras):
295
+ pipeline.load_lora_weights(lora['repo_id'], adapter_name=f"lora_{i}")
296
+ adapter_names = [f"lora_{i}" for i in range(len(loras))]
297
+ adapter_weights = [lora['weight'] for lora in loras]
298
+
299
+ if fast_generation:
300
+ hyper_lora = hf_hub_download(
301
+ "ByteDance/Hyper-SD",
302
+ "Hyper-FLUX.1-dev-8steps-lora.safetensors" if isinstance(pipeline, FluxPipeline) else "Hyper-SDXL-2steps-lora.safetensors"
303
+ )
304
+ hyper_weight = 0.125 if isinstance(pipeline, FluxPipeline) else 1.0
305
+ pipeline.load_lora_weights(hyper_lora, adapter_name="hyper_lora")
306
+ adapter_names.append("hyper_lora")
307
+ adapter_weights.append(hyper_weight)
308
+
309
+ pipeline.set_adapters(adapter_names, adapter_weights)
310
+
311
+
312
+ def load_xl_embeddings(pipeline, embeddings):
313
+ for embedding in embeddings:
314
+ state_dict = load_file(hf_hub_download(embedding['repo_id']))
315
+ pipeline.load_textual_inversion(state_dict['clip_g'], token=embedding['token'], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
316
+ pipeline.load_textual_inversion(state_dict["clip_l"], token=embedding['token'], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
317
+
318
+
319
+ def resize_images(images: List[Image.Image], height: int, width: int, resize_mode: str):
320
+ for image in images:
321
+ if resize_mode == "resize_only":
322
+ image = image.resize((width, height))
323
+ elif resize_mode == "crop_and_resize":
324
+ image = image.crop((0, 0, width, height))
325
+ elif resize_mode == "resize_and_fill":
326
+ image = image.resize((width, height), Image.Resampling.LANCZOS)
327
+
328
+ return images
329
+
330
+
331
+ def get_controlnet_images(controlnets: List[str], control_images: List[Image.Image], height: int, width: int, resize_mode: str):
332
+ response_images = []
333
+ control_images = resize_images(control_images, height, width, resize_mode)
334
+ for controlnet, image in zip(controlnets, control_images):
335
+ if controlnet == "canny" or controlnet == "canny_xs" or controlnet == "canny_fl":
336
+ processor = Processor('canny')
337
+ elif controlnet == "depth" or controlnet == "depth_xs" or controlnet == "depth_fl":
338
+ processor = Processor('depth_midas')
339
+ elif controlnet == "pose" or controlnet == "pose_fl":
340
+ processor = Processor('openpose_full')
341
+ elif controlnet == "scribble":
342
+ processor = Processor('scribble')
343
+ else:
344
+ raise ValueError(f"Invalid Controlnet: {controlnet}")
345
+
346
+ response_images.append(processor(image, to_pil=True))
347
+
348
+ return response_images
349
+
350
+
351
+ def check_image_safety(images: List[Image.Image]):
352
+ safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
353
+ has_nsfw_concepts = safety_checker(
354
+ images=[images],
355
+ clip_input=safety_checker_input.pixel_values.to("cuda"),
356
+ )
357
+
358
+ return has_nsfw_concepts[1]
359
+
360
+
361
+ def get_prompt_attention(pipeline, prompt, negative_prompt):
362
+ if isinstance(pipeline, (FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxControlNetPipeline)):
363
+ prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings_flux1(pipeline, prompt)
364
+ return prompt_embeds, None, pooled_prompt_embeds, None
365
+ elif isinstance(pipeline, StableDiffusionXLPipeline):
366
+ prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = get_weighted_text_embeddings_sdxl(pipeline, prompt, negative_prompt)
367
+ return prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
368
+ else:
369
+ raise ValueError(f"Invalid pipeline type: {type(pipeline)}")
370
+
371
+
372
+ def get_photomaker_images(photomaker_images: List[Image.Image], height: int, width: int, resize_mode: str):
373
+ image_input_ids = []
374
+ image_id_embeds = []
375
+ photomaker_images = resize_images(photomaker_images, height, width, resize_mode)
376
+
377
+ for image in photomaker_images:
378
+ image_input_ids.append(img)
379
+ img = np.array(image)[:, :, ::-1]
380
+ faces = analyze_faces(face_detector, image)
381
+ if len(faces) > 0:
382
+ image_id_embeds.append(torch.from_numpy(faces[0]['embeddings']))
383
+ else:
384
+ raise ValueError("No face detected in the image")
385
+
386
+ return image_input_ids, image_id_embeds
387
+
388
+
389
+ def cleanup(pipeline, loras = None, embeddings = None):
390
+ if loras:
391
+ pipeline.disable_lora()
392
+ pipeline.unload_lora_weights()
393
+ if embeddings:
394
+ pipeline.unload_textual_inversion()
395
+ gc.collect()
396
+ torch.cuda.empty_cache()
397
+
398
+
399
+ # Gen function
400
+ @spaces.GPU
401
+ def gen_img(
402
+ request: SDReq | SDImg2ImgReq | SDInpaintReq
403
+ ):
404
+ pipeline_args = get_pipe(request)
405
+ pipeline = pipeline_args['pipeline']
406
+ try:
407
+ pipeline.scheduler = load_scheduler(pipeline, request.scheduler)
408
+
409
+ load_loras(pipeline, request.loras, request.fast_generation)
410
+ load_xl_embeddings(pipeline, request.embeddings)
411
+
412
+ control_images = get_controlnet_images(request.controlnet_config.controlnets, request.controlnet_config.control_images, request.height, request.width, request.resize_mode) if request.controlnet_config else None
413
+ photomaker_images, photomaker_id_embeds = get_photomaker_images(request.photomaker_images, request.height, request.width) if request.photomaker_images else (None, None)
414
+
415
+ positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_prompt_attention(pipeline, request.prompt, request.negative_prompt)
416
+
417
+ # Common args
418
+ args = {
419
+ 'prompt_embeds': positive_prompt_embeds,
420
+ 'pooled_prompt_embeds': positive_prompt_pooled,
421
+ 'height': request.height,
422
+ 'width': request.width,
423
+ 'num_images_per_prompt': request.num_images_per_prompt,
424
+ 'num_inference_steps': request.num_inference_steps,
425
+ 'guidance_scale': request.guidance_scale,
426
+ 'generator': [torch.Generator(device=device).manual_seed(request.seed + i) if not request.seed is any([None, 0, -1]) else torch.Generator(device=device).manual_seed(random.randint(0, 2**32 - 1)) for i in range(request.num_images_per_prompt)],
427
+ }
428
+
429
+ if isinstance(pipeline, any([StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline,
430
+ StableDiffusionXLControlNetPipeline, StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetInpaintPipeline])):
431
+ args['clip_skip'] = request.clip_skip
432
+ args['negative_prompt_embeds'] = negative_prompt_embeds
433
+ args['negative_pooled_prompt_embeds'] = negative_prompt_pooled
434
+
435
+ if isinstance(pipeline, FluxControlNetPipeline) and request.controlnet_config:
436
+ args['control_mode'] = pipeline_args['control_mode']
437
+ args['control_image'] = control_images
438
+ args['controlnet_conditioning_scale'] = request.controlnet_conditioning_scale
439
+
440
+ if not isinstance(pipeline, FluxControlNetPipeline) and request.controlnet_config:
441
+ args['controlnet_conditioning_scale'] = request.controlnet_conditioning_scale
442
+
443
+ if isinstance(request, SDReq):
444
+ args['image'] = control_images
445
+ elif isinstance(request, (SDImg2ImgReq, SDInpaintReq)):
446
+ args['control_image'] = control_images
447
+
448
+ if request.photomaker_images and isinstance(pipeline, any([PhotoMakerStableDiffusionXLPipeline, PhotoMakerStableDiffusionXLControlNetPipeline])):
449
+ args['input_id_images'] = photomaker_images
450
+ args['input_id_embeds'] = photomaker_id_embeds
451
+ args['start_merge_step'] = 10
452
+
453
+ if isinstance(request, SDImg2ImgReq):
454
+ args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)
455
+ args['strength'] = request.strength
456
+ elif isinstance(request, SDInpaintReq):
457
+ args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)
458
+ args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)
459
+ args['strength'] = request.strength
460
+
461
+ images = pipeline(**args).images
462
+
463
+ if request.refiner:
464
+ images = refiner(
465
+ prompt=request.prompt,
466
+ num_inference_steps=40,
467
+ denoising_start=0.7,
468
+ image=images.images
469
+ ).images
470
+
471
+ cleanup(pipeline, request.loras, request.embeddings)
472
+
473
+ return images
474
+ except Exception as e:
475
+ cleanup(pipeline, request.loras, request.embeddings)
476
+ raise ValueError(f"Error generating image: {e}") from e
477
+
478
+
479
+ # CSS
480
  css = """
481
  @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;600&display=swap');
482
  body {
 
498
  """
499
 
500
 
501
+ flux_models = ["black-forest-labs/FLUX.1-dev"]
502
+ with open("data/images/loras/flux.json", "r") as f:
503
+ loras = json.load(f)
504
+
505
+
506
  # Main Gradio app
507
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
508
  # Header
 
516
  # Tabs
517
  with gr.Tabs():
518
  with gr.Tab(label="🖼️ Image"):
519
+ with gr.Tabs():
520
+ with gr.Tab("Flux"):
521
+ """
522
+ Create the image tab for Generative Image Generation Models
523
+
524
+ Args:
525
+ models: list
526
+ A list containing the models repository paths
527
+ gap_iol, gap_la, gap_le, gap_eio, gap_io: Optional[List[dict]]
528
+ A list of dictionaries containing the title and component for the custom gradio component
529
+ Example:
530
+ def gr_comp():
531
+ gr.Label("Hello World")
532
+
533
+ [
534
+ {
535
+ 'title': "Title",
536
+ 'component': gr_comp()
537
+ }
538
+ ]
539
+ loras: list
540
+ A list of dictionaries containing the image and title for the Loras Gallery
541
+ Generally a loaded json file from the data folder
542
+
543
+ """
544
+ def process_gaps(gaps: List[dict]):
545
+ for gap in gaps:
546
+ with gr.Accordion(gap['title']):
547
+ gap['component']
548
+
549
+
550
+ with gr.Row():
551
+ with gr.Column():
552
+ with gr.Group() as image_options:
553
+ model = gr.Dropdown(label="Models", choices=flux_models, value=flux_models[0], interactive=True)
554
+ prompt = gr.Textbox(lines=5, label="Prompt")
555
+ negative_prompt = gr.Textbox(label="Negative Prompt")
556
+ fast_generation = gr.Checkbox(label="Fast Generation (Hyper-SD) 🧪")
557
+
558
+
559
+ with gr.Accordion("Loras", open=True): # Lora Gallery
560
+ lora_gallery = gr.Gallery(
561
+ label="Gallery",
562
+ value=[(lora['image'], lora['title']) for lora in loras],
563
+ allow_preview=False,
564
+ columns=[3],
565
+ type="pil"
566
+ )
567
+
568
+ with gr.Group():
569
+ with gr.Column():
570
+ with gr.Row():
571
+ custom_lora = gr.Textbox(label="Custom Lora", info="Enter a Huggingface repo path")
572
+ selected_lora = gr.Textbox(label="Selected Lora", info="Choose from the gallery or enter a custom LoRA")
573
+
574
+ custom_lora_info = gr.HTML(visible=False)
575
+ add_lora = gr.Button(value="Add LoRA")
576
+
577
+ enabled_loras = gr.State(value=[])
578
+ with gr.Group():
579
+ with gr.Row():
580
+ for i in range(6): # only support max 6 loras due to inference time
581
+ with gr.Column():
582
+ with gr.Column(scale=2):
583
+ globals()[f"lora_slider_{i}"] = gr.Slider(label=f"LoRA {i+1}", minimum=0, maximum=1, step=0.01, value=0.8, visible=False, interactive=True)
584
+ with gr.Column():
585
+ globals()[f"lora_remove_{i}"] = gr.Button(value="Remove LoRA", visible=False)
586
+
587
+
588
+ with gr.Accordion("Embeddings", open=False): # Embeddings
589
+ gr.Label("To be implemented")
590
+
591
+
592
+ with gr.Accordion("Image Options"): # Image Options
593
+ with gr.Tabs():
594
+ image_options = {
595
+ "img2img": "Upload Image",
596
+ "inpaint": "Upload Image",
597
+ "canny": "Upload Image",
598
+ "pose": "Upload Image",
599
+ "depth": "Upload Image",
600
+ }
601
+
602
+ for image_option, label in image_options.items():
603
+ with gr.Tab(image_option):
604
+ if not image_option in ['inpaint', 'scribble']:
605
+ globals()[f"{image_option}_image"] = gr.Image(label=label, type="pil")
606
+ elif image_option in ['inpaint', 'scribble']:
607
+ globals()[f"{image_option}_image"] = gr.ImageEditor(
608
+ label=label,
609
+ image_mode='RGB',
610
+ layers=False,
611
+ brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed") if image_option == 'inpaint' else gr.Brush(),
612
+ interactive=True,
613
+ type="pil",
614
+ )
615
+
616
+ # Image Strength (Co-relates to controlnet strength, strength for img2img n inpaint)
617
+ globals()[f"{image_option}_strength"] = gr.Slider(label="Strength", minimum=0, maximum=1, step=0.01, value=1.0, interactive=True)
618
+
619
+ resize_mode = gr.Radio(
620
+ label="Resize Mode",
621
+ choices=["crop and resize", "resize only", "resize and fill"],
622
+ value="resize and fill",
623
+ interactive=True
624
+ )
625
+
626
+
627
+ with gr.Column():
628
+ with gr.Group():
629
+ output_images = gr.Gallery(
630
+ label="Output Images",
631
+ value=[],
632
+ allow_preview=True,
633
+ type="pil",
634
+ interactive=False,
635
+ )
636
+ generate_images = gr.Button(value="Generate Images", variant="primary")
637
+
638
+ with gr.Accordion("Advance Settings", open=True):
639
+ with gr.Row():
640
+ scheduler = gr.Dropdown(
641
+ label="Scheduler",
642
+ choices = [
643
+ "fm_euler"
644
+ ],
645
+ value="fm_euler",
646
+ interactive=True
647
+ )
648
+
649
+ with gr.Row():
650
+ for column in range(2):
651
+ with gr.Column():
652
+ options = [
653
+ ("Height", "image_height", 64, 1024, 64, 1024, True),
654
+ ("Width", "image_width", 64, 1024, 64, 1024, True),
655
+ ("Num Images Per Prompt", "image_num_images_per_prompt", 1, 4, 1, 1, True),
656
+ ("Num Inference Steps", "image_num_inference_steps", 1, 100, 1, 20, True),
657
+ ("Clip Skip", "image_clip_skip", 0, 2, 1, 2, False),
658
+ ("Guidance Scale", "image_guidance_scale", 0, 20, 0.5, 3.5, True),
659
+ ("Seed", "image_seed", 0, 100000, 1, random.randint(0, 100000), True),
660
+ ]
661
+ for label, var_name, min_val, max_val, step, value, visible in options[column::2]:
662
+ globals()[var_name] = gr.Slider(label=label, minimum=min_val, maximum=max_val, step=step, value=value, visible=visible, interactive=True)
663
+
664
+ with gr.Row():
665
+ refiner = gr.Checkbox(
666
+ label="Refiner 🧪",
667
+ value=False,
668
+ )
669
+ vae = gr.Checkbox(
670
+ label="VAE",
671
+ value=True,
672
+ )
673
+
674
+
675
+ # Events
676
+ # Base Options
677
+ fast_generation.change(update_fast_generation, [model, fast_generation], [image_guidance_scale, image_num_inference_steps]) # Fast Generation # type: ignore
678
+
679
+
680
+ # Lora Gallery
681
+ lora_gallery.select(selected_lora_from_gallery, None, selected_lora)
682
+ custom_lora.change(update_selected_lora, custom_lora, [custom_lora, selected_lora])
683
+ add_lora.click(add_to_enabled_loras, [model, selected_lora, enabled_loras], [selected_lora, custom_lora_info, enabled_loras])
684
+ enabled_loras.change(update_lora_sliders, enabled_loras, [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, lora_remove_0, lora_remove_1, lora_remove_2, lora_remove_3, lora_remove_4, lora_remove_5]) # type: ignore
685
+
686
+ for i in range(6):
687
+ globals()[f"lora_remove_{i}"].click(
688
+ lambda enabled_loras, index=i: remove_from_enabled_loras(enabled_loras, index),
689
+ [enabled_loras],
690
+ [enabled_loras]
691
+ )
692
+
693
+
694
+ # Generate Image
695
+ generate_images.click(
696
+ generate_image, # type: ignore
697
+ [
698
+ model, prompt, negative_prompt, fast_generation, enabled_loras,
699
+ lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, # type: ignore
700
+ img2img_image, inpaint_image, canny_image, pose_image, depth_image, # type: ignore
701
+ img2img_strength, inpaint_strength, canny_strength, pose_strength, depth_strength, # type: ignore
702
+ resize_mode,
703
+ scheduler, image_height, image_width, image_num_images_per_prompt, # type: ignore
704
+ image_num_inference_steps, image_guidance_scale, image_seed, # type: ignore
705
+ refiner, vae
706
+ ],
707
+ [output_images]
708
+ )
709
+ with gr.Tab("SDXL"):
710
+ gr.Label("To be implemented")
711
  with gr.Tab(label="🎵 Audio"):
712
  gr.Label("Coming soon!")
713
  with gr.Tab(label="🎬 Video"):
714
  gr.Label("Coming soon!")
715
  with gr.Tab(label="📄 Text"):
716
  gr.Label("Coming soon!")
717
+
718
+
719
  demo.launch(
720
  share=False,
721
  debug=True,
app2.py CHANGED
@@ -1,481 +1,11 @@
1
- # Testing one file gradio app for zero gpu spaces not working as expected.
2
- # Check here for the issue:
3
- import gc
4
- import json
5
- import random
6
- from typing import List, Optional
7
-
8
- import spaces
9
  import gradio as gr
10
- from huggingface_hub import ModelCard
11
- import torch
12
- import numpy as np
13
- from pydantic import BaseModel
14
- from PIL import Image
15
- from diffusers import (
16
- FluxPipeline,
17
- FluxImg2ImgPipeline,
18
- FluxInpaintPipeline,
19
- FluxControlNetPipeline,
20
- StableDiffusionXLPipeline,
21
- StableDiffusionXLImg2ImgPipeline,
22
- StableDiffusionXLInpaintPipeline,
23
- StableDiffusionXLControlNetPipeline,
24
- StableDiffusionXLControlNetImg2ImgPipeline,
25
- StableDiffusionXLControlNetInpaintPipeline,
26
- AutoPipelineForText2Image,
27
- AutoPipelineForImage2Image,
28
- AutoPipelineForInpainting,
29
- DiffusionPipeline,
30
- AutoencoderKL,
31
- FluxControlNetModel,
32
- FluxMultiControlNetModel,
33
- ControlNetModel,
34
- )
35
- from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
36
- from huggingface_hub import hf_hub_download
37
- from transformers import CLIPFeatureExtractor
38
- from photomaker import FaceAnalysis2
39
- from diffusers.schedulers import *
40
- from huggingface_hub import hf_hub_download
41
- from safetensors.torch import load_file
42
- from controlnet_aux.processor import Processor
43
- from photomaker import (
44
- PhotoMakerStableDiffusionXLPipeline,
45
- PhotoMakerStableDiffusionXLControlNetPipeline,
46
- analyze_faces
47
- )
48
- from sd_embed.embedding_funcs import get_weighted_text_embeddings_sdxl, get_weighted_text_embeddings_flux1
49
-
50
-
51
- # Initialize System
52
- def load_sd():
53
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
- device = "cuda" if torch.cuda.is_available() else "cpu"
55
-
56
- # Models
57
- models = [
58
- {
59
- "repo_id": "black-forest-labs/FLUX.1-dev",
60
- "loader": "flux",
61
- "compute_type": torch.bfloat16,
62
- },
63
- {
64
- "repo_id": "SG161222/RealVisXL_V4.0",
65
- "loader": "xl",
66
- "compute_type": torch.float16,
67
- }
68
- ]
69
-
70
- for model in models:
71
- try:
72
- model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
73
- model['repo_id'],
74
- torch_dtype = model['compute_type'],
75
- safety_checker = None,
76
- variant = "fp16"
77
- ).to(device)
78
- model["pipeline"].enable_model_cpu_offload()
79
- except:
80
- model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
81
- model['repo_id'],
82
- torch_dtype = model['compute_type'],
83
- safety_checker = None
84
- ).to(device)
85
- model["pipeline"].enable_model_cpu_offload()
86
-
87
-
88
- # VAE n Refiner
89
- sdxl_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
90
- refiner = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", vae=sdxl_vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
91
- refiner.enable_model_cpu_offload()
92
-
93
-
94
- # Safety Checker
95
- safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device)
96
- feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", from_pt=True)
97
-
98
-
99
- # Controlnets
100
- controlnet_models = [
101
- {
102
- "repo_id": "xinsir/controlnet-depth-sdxl-1.0",
103
- "name": "depth_xl",
104
- "layers": ["depth"],
105
- "loader": "xl",
106
- "compute_type": torch.float16,
107
- },
108
- {
109
- "repo_id": "xinsir/controlnet-canny-sdxl-1.0",
110
- "name": "canny_xl",
111
- "layers": ["canny"],
112
- "loader": "xl",
113
- "compute_type": torch.float16,
114
- },
115
- {
116
- "repo_id": "xinsir/controlnet-openpose-sdxl-1.0",
117
- "name": "openpose_xl",
118
- "layers": ["pose"],
119
- "loader": "xl",
120
- "compute_type": torch.float16,
121
- },
122
- {
123
- "repo_id": "xinsir/controlnet-scribble-sdxl-1.0",
124
- "name": "scribble_xl",
125
- "layers": ["scribble"],
126
- "loader": "xl",
127
- "compute_type": torch.float16,
128
- },
129
- {
130
- "repo_id": "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
131
- "name": "flux1_union_pro",
132
- "layers": ["canny_fl", "tile_fl", "depth_fl", "blur_fl", "pose_fl", "gray_fl", "low_quality_fl"],
133
- "loader": "flux-multi",
134
- "compute_type": torch.bfloat16,
135
- }
136
- ]
137
-
138
- for controlnet in controlnet_models:
139
- if controlnet["loader"] == "xl":
140
- controlnet["controlnet"] = ControlNetModel.from_pretrained(
141
- controlnet["repo_id"],
142
- torch_dtype = controlnet['compute_type']
143
- ).to(device)
144
- elif controlnet["loader"] == "flux-multi":
145
- controlnet["controlnet"] = FluxMultiControlNetModel([FluxControlNetModel.from_pretrained(
146
- controlnet["repo_id"],
147
- torch_dtype = controlnet['compute_type']
148
- ).to(device)])
149
- #TODO: Add support for flux only controlnet
150
-
151
-
152
- # Face Detection (for PhotoMaker)
153
- face_detector = FaceAnalysis2(providers=['CUDAExecutionProvider'], allowed_modules=['detection', 'recognition'])
154
- face_detector.prepare(ctx_id=0, det_size=(640, 640))
155
-
156
-
157
- # PhotoMaker V2 (for SDXL only)
158
- photomaker_ckpt = hf_hub_download(repo_id="TencentARC/PhotoMaker-V2", filename="photomaker-v2.bin", repo_type="model")
159
-
160
- return device, models, sdxl_vae, refiner, safety_checker, feature_extractor, controlnet_models, face_detector, photomaker_ckpt
161
-
162
-
163
- device, models, sdxl_vae, refiner, safety_checker, feature_extractor, controlnet_models, face_detector, photomaker_ckpt = load_sd()
164
-
165
-
166
- # Models
167
- class ControlNetReq(BaseModel):
168
- controlnets: List[str] # ["canny", "tile", "depth"]
169
- control_images: List[Image.Image]
170
- controlnet_conditioning_scale: List[float]
171
-
172
- class Config:
173
- arbitrary_types_allowed=True
174
-
175
-
176
- class SDReq(BaseModel):
177
- model: str = ""
178
- prompt: str = ""
179
- negative_prompt: Optional[str] = "black-forest-labs/FLUX.1-dev"
180
- fast_generation: Optional[bool] = True
181
- loras: Optional[list] = []
182
- embeddings: Optional[list] = []
183
- resize_mode: Optional[str] = "resize_and_fill" # resize_only, crop_and_resize, resize_and_fill
184
- scheduler: Optional[str] = "euler_fl"
185
- height: int = 1024
186
- width: int = 1024
187
- num_images_per_prompt: int = 1
188
- num_inference_steps: int = 8
189
- guidance_scale: float = 3.5
190
- seed: Optional[int] = 0
191
- refiner: bool = False
192
- vae: bool = True
193
- controlnet_config: Optional[ControlNetReq] = None
194
- photomaker_images: Optional[List[Image.Image]] = None
195
-
196
- class Config:
197
- arbitrary_types_allowed=True
198
-
199
-
200
- class SDImg2ImgReq(SDReq):
201
- image: Image.Image
202
- strength: float = 1.0
203
-
204
- class Config:
205
- arbitrary_types_allowed=True
206
-
207
-
208
- class SDInpaintReq(SDImg2ImgReq):
209
- mask_image: Image.Image
210
-
211
- class Config:
212
- arbitrary_types_allowed=True
213
-
214
-
215
- # Helper functions
216
- def get_controlnet(controlnet_config: ControlNetReq):
217
- control_mode = []
218
- controlnet = []
219
-
220
- for m in controlnet_models:
221
- for c in controlnet_config.controlnets:
222
- if c in m["layers"]:
223
- control_mode.append(m["layers"].index(c))
224
- controlnet.append(m["controlnet"])
225
-
226
- return controlnet, control_mode
227
-
228
-
229
- def get_pipe(request: SDReq | SDImg2ImgReq | SDInpaintReq):
230
- for m in models:
231
- if m["repo_id"] == request.model:
232
- pipeline = m['pipeline']
233
- controlnet, control_mode = get_controlnet(request.controlnet_config) if request.controlnet_config else (None, None)
234
-
235
- pipe_args = {
236
- "pipeline": pipeline,
237
- "control_mode": control_mode,
238
- }
239
- if request.controlnet_config:
240
- pipe_args["controlnet"] = controlnet
241
-
242
- if not request.photomaker_images:
243
- if isinstance(request, SDReq):
244
- pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
245
- elif isinstance(request, SDImg2ImgReq):
246
- pipe_args['pipeline'] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
247
- elif isinstance(request, SDInpaintReq):
248
- pipe_args['pipeline'] = AutoPipelineForInpainting.from_pipe(**pipe_args)
249
- else:
250
- raise ValueError(f"Unknown request type: {type(request)}")
251
- elif isinstance(request, any([PhotoMakerStableDiffusionXLPipeline, PhotoMakerStableDiffusionXLControlNetPipeline])):
252
- if request.controlnet_config:
253
- pipe_args['pipeline'] = PhotoMakerStableDiffusionXLControlNetPipeline.from_pipe(**pipe_args)
254
- else:
255
- pipe_args['pipeline'] = PhotoMakerStableDiffusionXLPipeline.from_pipe(**pipe_args)
256
- else:
257
- raise ValueError(f"Invalid request type: {type(request)}")
258
-
259
- return pipe_args
260
-
261
-
262
- def load_scheduler(pipeline, scheduler):
263
- schedulers = {
264
- "dpmpp_2m": (DPMSolverMultistepScheduler, {}),
265
- "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
266
- "dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++"}),
267
- "dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True}),
268
- "dpmpp_sde": (DPMSolverSinglestepScheduler, {}),
269
- "dpmpp_sde_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
270
- "dpm2": (KDPM2DiscreteScheduler, {}),
271
- "dpm2_k": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
272
- "dpm2_a": (KDPM2AncestralDiscreteScheduler, {}),
273
- "dpm2_a_k": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
274
- "euler": (EulerDiscreteScheduler, {}),
275
- "euler_a": (EulerAncestralDiscreteScheduler, {}),
276
- "heun": (HeunDiscreteScheduler, {}),
277
- "lms": (LMSDiscreteScheduler, {}),
278
- "lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
279
- "deis": (DEISMultistepScheduler, {}),
280
- "unipc": (UniPCMultistepScheduler, {}),
281
- "fm_euler": (FlowMatchEulerDiscreteScheduler, {}),
282
- }
283
- scheduler_class, kwargs = schedulers.get(scheduler, (None, {}))
284
-
285
- if scheduler_class is not None:
286
- scheduler = scheduler_class.from_config(pipeline.scheduler.config, **kwargs)
287
- else:
288
- raise ValueError(f"Unknown scheduler: {scheduler}")
289
-
290
- return scheduler
291
-
292
-
293
- def load_loras(pipeline, loras, fast_generation):
294
- for i, lora in enumerate(loras):
295
- pipeline.load_lora_weights(lora['repo_id'], adapter_name=f"lora_{i}")
296
- adapter_names = [f"lora_{i}" for i in range(len(loras))]
297
- adapter_weights = [lora['weight'] for lora in loras]
298
-
299
- if fast_generation:
300
- hyper_lora = hf_hub_download(
301
- "ByteDance/Hyper-SD",
302
- "Hyper-FLUX.1-dev-8steps-lora.safetensors" if isinstance(pipeline, FluxPipeline) else "Hyper-SDXL-2steps-lora.safetensors"
303
- )
304
- hyper_weight = 0.125 if isinstance(pipeline, FluxPipeline) else 1.0
305
- pipeline.load_lora_weights(hyper_lora, adapter_name="hyper_lora")
306
- adapter_names.append("hyper_lora")
307
- adapter_weights.append(hyper_weight)
308
-
309
- pipeline.set_adapters(adapter_names, adapter_weights)
310
-
311
-
312
- def load_xl_embeddings(pipeline, embeddings):
313
- for embedding in embeddings:
314
- state_dict = load_file(hf_hub_download(embedding['repo_id']))
315
- pipeline.load_textual_inversion(state_dict['clip_g'], token=embedding['token'], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
316
- pipeline.load_textual_inversion(state_dict["clip_l"], token=embedding['token'], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
317
-
318
-
319
- def resize_images(images: List[Image.Image], height: int, width: int, resize_mode: str):
320
- for image in images:
321
- if resize_mode == "resize_only":
322
- image = image.resize((width, height))
323
- elif resize_mode == "crop_and_resize":
324
- image = image.crop((0, 0, width, height))
325
- elif resize_mode == "resize_and_fill":
326
- image = image.resize((width, height), Image.Resampling.LANCZOS)
327
-
328
- return images
329
-
330
-
331
- def get_controlnet_images(controlnets: List[str], control_images: List[Image.Image], height: int, width: int, resize_mode: str):
332
- response_images = []
333
- control_images = resize_images(control_images, height, width, resize_mode)
334
- for controlnet, image in zip(controlnets, control_images):
335
- if controlnet == "canny" or controlnet == "canny_xs" or controlnet == "canny_fl":
336
- processor = Processor('canny')
337
- elif controlnet == "depth" or controlnet == "depth_xs" or controlnet == "depth_fl":
338
- processor = Processor('depth_midas')
339
- elif controlnet == "pose" or controlnet == "pose_fl":
340
- processor = Processor('openpose_full')
341
- elif controlnet == "scribble":
342
- processor = Processor('scribble')
343
- else:
344
- raise ValueError(f"Invalid Controlnet: {controlnet}")
345
-
346
- response_images.append(processor(image, to_pil=True))
347
-
348
- return response_images
349
-
350
-
351
- def check_image_safety(images: List[Image.Image]):
352
- safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
353
- has_nsfw_concepts = safety_checker(
354
- images=[images],
355
- clip_input=safety_checker_input.pixel_values.to("cuda"),
356
- )
357
-
358
- return has_nsfw_concepts[1]
359
-
360
-
361
- def get_prompt_attention(pipeline, prompt, negative_prompt):
362
- if isinstance(pipeline, (FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxControlNetPipeline)):
363
- prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings_flux1(pipeline, prompt)
364
- return prompt_embeds, None, pooled_prompt_embeds, None
365
- elif isinstance(pipeline, StableDiffusionXLPipeline):
366
- prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = get_weighted_text_embeddings_sdxl(pipeline, prompt, negative_prompt)
367
- return prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
368
- else:
369
- raise ValueError(f"Invalid pipeline type: {type(pipeline)}")
370
-
371
-
372
- def get_photomaker_images(photomaker_images: List[Image.Image], height: int, width: int, resize_mode: str):
373
- image_input_ids = []
374
- image_id_embeds = []
375
- photomaker_images = resize_images(photomaker_images, height, width, resize_mode)
376
-
377
- for image in photomaker_images:
378
- image_input_ids.append(img)
379
- img = np.array(image)[:, :, ::-1]
380
- faces = analyze_faces(face_detector, image)
381
- if len(faces) > 0:
382
- image_id_embeds.append(torch.from_numpy(faces[0]['embeddings']))
383
- else:
384
- raise ValueError("No face detected in the image")
385
-
386
- return image_input_ids, image_id_embeds
387
-
388
-
389
- def cleanup(pipeline, loras = None, embeddings = None):
390
- if loras:
391
- pipeline.disable_lora()
392
- pipeline.unload_lora_weights()
393
- if embeddings:
394
- pipeline.unload_textual_inversion()
395
- gc.collect()
396
- torch.cuda.empty_cache()
397
-
398
 
399
- # Gen function
400
- def gen_img(
401
- request: SDReq | SDImg2ImgReq | SDInpaintReq
402
- ):
403
- pipeline_args = get_pipe(request)
404
- pipeline = pipeline_args['pipeline']
405
- try:
406
- pipeline.scheduler = load_scheduler(pipeline, request.scheduler)
407
-
408
- load_loras(pipeline, request.loras, request.fast_generation)
409
- load_xl_embeddings(pipeline, request.embeddings)
410
-
411
- control_images = get_controlnet_images(request.controlnet_config.controlnets, request.controlnet_config.control_images, request.height, request.width, request.resize_mode) if request.controlnet_config else None
412
- photomaker_images, photomaker_id_embeds = get_photomaker_images(request.photomaker_images, request.height, request.width) if request.photomaker_images else (None, None)
413
-
414
- positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_prompt_attention(pipeline, request.prompt, request.negative_prompt)
415
-
416
- # Common args
417
- args = {
418
- 'prompt_embeds': positive_prompt_embeds,
419
- 'pooled_prompt_embeds': positive_prompt_pooled,
420
- 'height': request.height,
421
- 'width': request.width,
422
- 'num_images_per_prompt': request.num_images_per_prompt,
423
- 'num_inference_steps': request.num_inference_steps,
424
- 'guidance_scale': request.guidance_scale,
425
- 'generator': [torch.Generator(device=device).manual_seed(request.seed + i) if not request.seed is any([None, 0, -1]) else torch.Generator(device=device).manual_seed(random.randint(0, 2**32 - 1)) for i in range(request.num_images_per_prompt)],
426
- }
427
-
428
- if isinstance(pipeline, any([StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline,
429
- StableDiffusionXLControlNetPipeline, StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetInpaintPipeline])):
430
- args['clip_skip'] = request.clip_skip
431
- args['negative_prompt_embeds'] = negative_prompt_embeds
432
- args['negative_pooled_prompt_embeds'] = negative_prompt_pooled
433
-
434
- if isinstance(pipeline, FluxControlNetPipeline) and request.controlnet_config:
435
- args['control_mode'] = pipeline_args['control_mode']
436
- args['control_image'] = control_images
437
- args['controlnet_conditioning_scale'] = request.controlnet_conditioning_scale
438
-
439
- if not isinstance(pipeline, FluxControlNetPipeline) and request.controlnet_config:
440
- args['controlnet_conditioning_scale'] = request.controlnet_conditioning_scale
441
-
442
- if isinstance(request, SDReq):
443
- args['image'] = control_images
444
- elif isinstance(request, (SDImg2ImgReq, SDInpaintReq)):
445
- args['control_image'] = control_images
446
-
447
- if request.photomaker_images and isinstance(pipeline, any([PhotoMakerStableDiffusionXLPipeline, PhotoMakerStableDiffusionXLControlNetPipeline])):
448
- args['input_id_images'] = photomaker_images
449
- args['input_id_embeds'] = photomaker_id_embeds
450
- args['start_merge_step'] = 10
451
-
452
- if isinstance(request, SDImg2ImgReq):
453
- args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)
454
- args['strength'] = request.strength
455
- elif isinstance(request, SDInpaintReq):
456
- args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)
457
- args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)
458
- args['strength'] = request.strength
459
-
460
- images = pipeline(**args).images
461
-
462
- if request.refiner:
463
- images = refiner(
464
- prompt=request.prompt,
465
- num_inference_steps=40,
466
- denoising_start=0.7,
467
- image=images.images
468
- ).images
469
-
470
- cleanup(pipeline, request.loras, request.embeddings)
471
-
472
- return images
473
- except Exception as e:
474
- cleanup(pipeline, request.loras, request.embeddings)
475
- raise ValueError(f"Error generating image: {e}") from e
476
 
477
 
478
- # CSS
479
  css = """
480
  @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;600&display=swap');
481
  body {
@@ -497,11 +27,6 @@ body {
497
  """
498
 
499
 
500
- flux_models = ["black-forest-labs/FLUX.1-dev"]
501
- with open("data/images/loras/flux.json", "r") as f:
502
- loras = json.load(f)
503
-
504
-
505
  # Main Gradio app
506
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
507
  # Header
@@ -515,206 +40,14 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
515
  # Tabs
516
  with gr.Tabs():
517
  with gr.Tab(label="🖼️ Image"):
518
- with gr.Tabs():
519
- with gr.Tab("Flux"):
520
- """
521
- Create the image tab for Generative Image Generation Models
522
-
523
- Args:
524
- models: list
525
- A list containing the models repository paths
526
- gap_iol, gap_la, gap_le, gap_eio, gap_io: Optional[List[dict]]
527
- A list of dictionaries containing the title and component for the custom gradio component
528
- Example:
529
- def gr_comp():
530
- gr.Label("Hello World")
531
-
532
- [
533
- {
534
- 'title': "Title",
535
- 'component': gr_comp()
536
- }
537
- ]
538
- loras: list
539
- A list of dictionaries containing the image and title for the Loras Gallery
540
- Generally a loaded json file from the data folder
541
-
542
- """
543
- def process_gaps(gaps: List[dict]):
544
- for gap in gaps:
545
- with gr.Accordion(gap['title']):
546
- gap['component']
547
-
548
-
549
- with gr.Row():
550
- with gr.Column():
551
- with gr.Group() as image_options:
552
- model = gr.Dropdown(label="Models", choices=flux_models, value=flux_models[0], interactive=True)
553
- prompt = gr.Textbox(lines=5, label="Prompt")
554
- negative_prompt = gr.Textbox(label="Negative Prompt")
555
- fast_generation = gr.Checkbox(label="Fast Generation (Hyper-SD) 🧪")
556
-
557
-
558
- with gr.Accordion("Loras", open=True): # Lora Gallery
559
- lora_gallery = gr.Gallery(
560
- label="Gallery",
561
- value=[(lora['image'], lora['title']) for lora in loras],
562
- allow_preview=False,
563
- columns=[3],
564
- type="pil"
565
- )
566
-
567
- with gr.Group():
568
- with gr.Column():
569
- with gr.Row():
570
- custom_lora = gr.Textbox(label="Custom Lora", info="Enter a Huggingface repo path")
571
- selected_lora = gr.Textbox(label="Selected Lora", info="Choose from the gallery or enter a custom LoRA")
572
-
573
- custom_lora_info = gr.HTML(visible=False)
574
- add_lora = gr.Button(value="Add LoRA")
575
-
576
- enabled_loras = gr.State(value=[])
577
- with gr.Group():
578
- with gr.Row():
579
- for i in range(6): # only support max 6 loras due to inference time
580
- with gr.Column():
581
- with gr.Column(scale=2):
582
- globals()[f"lora_slider_{i}"] = gr.Slider(label=f"LoRA {i+1}", minimum=0, maximum=1, step=0.01, value=0.8, visible=False, interactive=True)
583
- with gr.Column():
584
- globals()[f"lora_remove_{i}"] = gr.Button(value="Remove LoRA", visible=False)
585
-
586
-
587
- with gr.Accordion("Embeddings", open=False): # Embeddings
588
- gr.Label("To be implemented")
589
-
590
-
591
- with gr.Accordion("Image Options"): # Image Options
592
- with gr.Tabs():
593
- image_options = {
594
- "img2img": "Upload Image",
595
- "inpaint": "Upload Image",
596
- "canny": "Upload Image",
597
- "pose": "Upload Image",
598
- "depth": "Upload Image",
599
- }
600
-
601
- for image_option, label in image_options.items():
602
- with gr.Tab(image_option):
603
- if not image_option in ['inpaint', 'scribble']:
604
- globals()[f"{image_option}_image"] = gr.Image(label=label, type="pil")
605
- elif image_option in ['inpaint', 'scribble']:
606
- globals()[f"{image_option}_image"] = gr.ImageEditor(
607
- label=label,
608
- image_mode='RGB',
609
- layers=False,
610
- brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed") if image_option == 'inpaint' else gr.Brush(),
611
- interactive=True,
612
- type="pil",
613
- )
614
-
615
- # Image Strength (Co-relates to controlnet strength, strength for img2img n inpaint)
616
- globals()[f"{image_option}_strength"] = gr.Slider(label="Strength", minimum=0, maximum=1, step=0.01, value=1.0, interactive=True)
617
-
618
- resize_mode = gr.Radio(
619
- label="Resize Mode",
620
- choices=["crop and resize", "resize only", "resize and fill"],
621
- value="resize and fill",
622
- interactive=True
623
- )
624
-
625
-
626
- with gr.Column():
627
- with gr.Group():
628
- output_images = gr.Gallery(
629
- label="Output Images",
630
- value=[],
631
- allow_preview=True,
632
- type="pil",
633
- interactive=False,
634
- )
635
- generate_images = gr.Button(value="Generate Images", variant="primary")
636
-
637
- with gr.Accordion("Advance Settings", open=True):
638
- with gr.Row():
639
- scheduler = gr.Dropdown(
640
- label="Scheduler",
641
- choices = [
642
- "fm_euler"
643
- ],
644
- value="fm_euler",
645
- interactive=True
646
- )
647
-
648
- with gr.Row():
649
- for column in range(2):
650
- with gr.Column():
651
- options = [
652
- ("Height", "image_height", 64, 1024, 64, 1024, True),
653
- ("Width", "image_width", 64, 1024, 64, 1024, True),
654
- ("Num Images Per Prompt", "image_num_images_per_prompt", 1, 4, 1, 1, True),
655
- ("Num Inference Steps", "image_num_inference_steps", 1, 100, 1, 20, True),
656
- ("Clip Skip", "image_clip_skip", 0, 2, 1, 2, False),
657
- ("Guidance Scale", "image_guidance_scale", 0, 20, 0.5, 3.5, True),
658
- ("Seed", "image_seed", 0, 100000, 1, random.randint(0, 100000), True),
659
- ]
660
- for label, var_name, min_val, max_val, step, value, visible in options[column::2]:
661
- globals()[var_name] = gr.Slider(label=label, minimum=min_val, maximum=max_val, step=step, value=value, visible=visible, interactive=True)
662
-
663
- with gr.Row():
664
- refiner = gr.Checkbox(
665
- label="Refiner 🧪",
666
- value=False,
667
- )
668
- vae = gr.Checkbox(
669
- label="VAE",
670
- value=True,
671
- )
672
-
673
-
674
- # Events
675
- # Base Options
676
- fast_generation.change(update_fast_generation, [model, fast_generation], [image_guidance_scale, image_num_inference_steps]) # Fast Generation # type: ignore
677
-
678
-
679
- # Lora Gallery
680
- lora_gallery.select(selected_lora_from_gallery, None, selected_lora)
681
- custom_lora.change(update_selected_lora, custom_lora, [custom_lora, selected_lora])
682
- add_lora.click(add_to_enabled_loras, [model, selected_lora, enabled_loras], [selected_lora, custom_lora_info, enabled_loras])
683
- enabled_loras.change(update_lora_sliders, enabled_loras, [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, lora_remove_0, lora_remove_1, lora_remove_2, lora_remove_3, lora_remove_4, lora_remove_5]) # type: ignore
684
-
685
- for i in range(6):
686
- globals()[f"lora_remove_{i}"].click(
687
- lambda enabled_loras, index=i: remove_from_enabled_loras(enabled_loras, index),
688
- [enabled_loras],
689
- [enabled_loras]
690
- )
691
-
692
-
693
- # Generate Image
694
- generate_images.click(
695
- generate_image, # type: ignore
696
- [
697
- model, prompt, negative_prompt, fast_generation, enabled_loras,
698
- lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, # type: ignore
699
- img2img_image, inpaint_image, canny_image, pose_image, depth_image, # type: ignore
700
- img2img_strength, inpaint_strength, canny_strength, pose_strength, depth_strength, # type: ignore
701
- resize_mode,
702
- scheduler, image_height, image_width, image_num_images_per_prompt, # type: ignore
703
- image_num_inference_steps, image_guidance_scale, image_seed, # type: ignore
704
- refiner, vae
705
- ],
706
- [output_images]
707
- )
708
- with gr.Tab("SDXL"):
709
- gr.Label("To be implemented")
710
  with gr.Tab(label="🎵 Audio"):
711
  gr.Label("Coming soon!")
712
  with gr.Tab(label="🎬 Video"):
713
  gr.Label("Coming soon!")
714
  with gr.Tab(label="📄 Text"):
715
  gr.Label("Coming soon!")
716
-
717
-
718
  demo.launch(
719
  share=False,
720
  debug=True,
1
  import gradio as gr
2
+ import spaces
3
 
4
+ from src.ui import (
5
+ image_tab,
6
+ )
7
 
8
 
9
  css = """
10
  @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;600&display=swap');
11
  body {
 
27
  """
28
 
29
 
30
  # Main Gradio app
31
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
32
  # Header
 
40
  # Tabs
41
  with gr.Tabs():
42
  with gr.Tab(label="🖼️ Image"):
43
+ image_tab()
44
  with gr.Tab(label="🎵 Audio"):
45
  gr.Label("Coming soon!")
46
  with gr.Tab(label="🎬 Video"):
47
  gr.Label("Coming soon!")
48
  with gr.Tab(label="📄 Text"):
49
  gr.Label("Coming soon!")
50
+
 
51
  demo.launch(
52
  share=False,
53
  debug=True,