barreloflube committed
Commit 07dc8e6
Parent(s): 3a5022f

Refactor ControlNetReq class to remove unused import and add controlnets, control_images, and controlnet_conditioning_scale attributes

modules/helpers/common_helpers.py CHANGED
@@ -7,8 +7,6 @@ from PIL import Image
 from diffusers.schedulers import *
 from controlnet_aux.processor import Processor
 
-from .flux_helpers import ControlNetReq
-
 
 class ControlNetReq(BaseModel):
     controlnets: List[str] # ["canny", "tile", "depth"]
 
7
  from diffusers.schedulers import *
8
  from controlnet_aux.processor import Processor
9
 
 
 
10
 
11
  class ControlNetReq(BaseModel):
12
  controlnets: List[str] # ["canny", "tile", "depth"]
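
The hunk above only shows the first field of the relocated class; per the commit message, ControlNetReq now also carries control_images and controlnet_conditioning_scale. A plausible sketch of the post-commit model, with field names taken from the commit message and the call sites in flux_helpers.py; the types beyond controlnets are assumptions:

    from typing import List
    from PIL import Image
    from pydantic import BaseModel

    class ControlNetReq(BaseModel):
        class Config:
            arbitrary_types_allowed = True  # allow PIL.Image.Image as a field type

        controlnets: List[str]                      # e.g. ["canny", "tile", "depth"]
        control_images: List[Image.Image]           # one conditioning image per controlnet
        controlnet_conditioning_scale: List[float]  # assumed per-controlnet strengths
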
modules/helpers/flux_helpers.py CHANGED
@@ -36,21 +36,20 @@ def load_sd():
         try:
             model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
                 model['repo_id'],
-                vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to(device),
-                torch_dtype = model['compute_type'],
-                safety_checker = None,
-                variant = "fp16"
+                vae=AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device),
+                torch_dtype=model['compute_type'],
+                safety_checker=None,
+                variant="fp16"
             ).to(device)
         except:
             model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
                 model['repo_id'],
-                vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device),
-                torch_dtype = model['compute_type'],
-                safety_checker = None
+                vae=AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device),
+                torch_dtype=model['compute_type'],
+                safety_checker=None
             ).to(device)
-
-        model["pipeline"].enable_model_cpu_offload()
 
+        model["pipeline"].enable_model_cpu_offload()
 
     # VAE n Refiner
     flux_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
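
After this hunk both branches of the try/except build the identical FLUX.1-dev AutoencoderKL; the only remaining difference is the variant="fp16" flag, so the bare except chiefly covers repos that publish no fp16 variant. A minimal sketch of a possible follow-up that hoists the shared VAE out and narrows the exception (a hypothetical refactor, not part of this commit, assuming the same load_sd() scope for model and device):

    import torch
    from diffusers import AutoencoderKL, AutoPipelineForText2Image

    flux_dev_vae = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
    ).to(device)
    common = dict(vae=flux_dev_vae, torch_dtype=model['compute_type'], safety_checker=None)
    try:
        # Prefer the fp16 checkpoint variant when the repo ships one
        model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
            model['repo_id'], variant="fp16", **common
        ).to(device)
    except Exception:  # e.g. no fp16 variant published for this repo
        model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
            model['repo_id'], **common
        ).to(device)
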
@@ -58,13 +57,12 @@ def load_sd():
     refiner = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", vae=sdxl_vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
     refiner.enable_model_cpu_offload()
 
-
     # ControlNet
     controlnet = FluxMultiControlNetModel([FluxControlNetModel.from_pretrained(
         "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
         torch_dtype=torch.bfloat16
     ).to(device)])
-
+
     return device, models, flux_vae, sdxl_vae, refiner, controlnet
 
 
@@ -74,11 +72,11 @@ device, models, flux_vae, sdxl_vae, refiner, controlnet = load_sd()
 def get_control_mode(controlnet_config: ControlNetReq):
     control_mode = []
     layers = ["canny", "tile", "depth", "blur", "pose", "gray", "low_quality"]
-
+
     for c in controlnet_config.controlnets:
         if c in layers:
             control_mode.append(layers.index(c))
-
+
     return control_mode
 
 
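
For reference, a standalone restatement of the mapping this helper computes (control_mode_for is a hypothetical name; the assertions show the Union-Pro mode indices):

    layers = ["canny", "tile", "depth", "blur", "pose", "gray", "low_quality"]

    def control_mode_for(controlnets):
        # Each requested control type maps to its position in the fixed layer
        # list; unrecognized names are silently dropped, as in get_control_mode.
        return [layers.index(c) for c in controlnets if c in layers]

    assert control_mode_for(["canny", "depth"]) == [0, 2]
    assert control_mode_for(["pose", "unknown"]) == [4]
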
@@ -88,14 +86,12 @@ def get_pipe(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
     pipe_args = {
         "pipeline": m['pipeline'],
     }
-
-
+
     # Set ControlNet config
     if request.controlnet_config:
         pipe_args["control_mode"] = get_control_mode(request.controlnet_config)
         pipe_args["controlnet"] = [controlnet]
-
-
+
     # Choose Pipeline Mode
     if isinstance(request, BaseReq):
         pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
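
The mode selection above relies on the AutoPipeline from_pipe constructor, which rebuilds a pipeline for another task from components that are already in memory rather than reloading weights. A minimal sketch of the mechanism (the model id is illustrative):

    import torch
    from diffusers import AutoPipelineForImage2Image, AutoPipelineForText2Image

    base = AutoPipelineForText2Image.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    )
    # Shares the already-loaded components with `base`: no re-download, no extra VRAM.
    img2img = AutoPipelineForImage2Image.from_pipe(base)
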
@@ -103,35 +99,32 @@ def get_pipe(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
         pipe_args['pipeline'] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
     elif isinstance(request, BaseInpaintReq):
         pipe_args['pipeline'] = AutoPipelineForInpainting.from_pipe(**pipe_args)
-
-
+
     # Enable or Disable Refiner
     if request.vae:
         pipe_args["pipeline"].vae = flux_vae
     elif not request.vae:
         pipe_args["pipeline"].vae = None
-
-
+
     # Set Scheduler
     pipe_args["pipeline"].scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe_args["pipeline"].scheduler.config)
-
-
+
     # Set Loras
     if request.loras:
         for i, lora in enumerate(request.loras):
             pipe_args["pipeline"].load_lora_weights(request.lora['repo_id'], adapter_name=f"lora_{i}")
         adapter_names = [f"lora_{i}" for i in range(len(request.loras))]
         adapter_weights = [lora['weight'] for lora in request.loras]
-
+
         if request.fast_generation:
             hyper_lora = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
             hyper_weight = 0.125
             pipe_args["pipeline"].load_lora_weights(hyper_lora, adapter_name="hyper_lora")
             adapter_names.append("hyper_lora")
             adapter_weights.append(hyper_weight)
-
+
         pipe_args["pipeline"].set_adapters(adapter_names, adapter_weights)
-
+
     return pipe_args
 
 
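
One pre-existing issue survives this hunk: inside the loop the code loads request.lora['repo_id'] even though the iteration variable is lora, so the per-LoRA repo ids from request.loras are never read. The presumably intended call, for comparison:

    for i, lora in enumerate(request.loras):
        # use the loop variable rather than request.lora
        pipe_args["pipeline"].load_lora_weights(lora['repo_id'], adapter_name=f"lora_{i}")
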
@@ -145,7 +138,7 @@ def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
     pipeline = pipe_args["pipeline"]
     try:
         positive_prompt_embeds, positive_prompt_pooled = get_prompt_attention(pipeline, request.prompt)
-
+
         # Common Args
         args = {
             'prompt_embeds': positive_prompt_embeds,
@@ -157,28 +150,28 @@ def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
             'guidance_scale': request.guidance_scale,
             'generator': [torch.Generator(device=device).manual_seed(request.seed + i) if not request.seed is any([None, 0, -1]) else torch.Generator(device=device).manual_seed(random.randint(0, 2**32 - 1)) for i in range(request.num_images_per_prompt)],
         }
-
+
         if request.controlnet_config:
             args['control_mode'] = get_control_mode(request.controlnet_config)
             args['control_images'] = get_controlnet_images(request.controlnet_config, request.height, request.width, request.resize_mode)
             args['controlnet_conditioning_scale'] = request.controlnet_config.controlnet_conditioning_scale
-
+
         if isinstance(request, (BaseImg2ImgReq, BaseInpaintReq)):
             args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)[0]
             args['strength'] = request.strength
-
+
         if isinstance(request, BaseInpaintReq):
             args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)[0]
-
+
         # Generate
         images = pipeline(**args).images
-
+
         # Refiner
         if request.refiner:
             images = refiner(image=images, prompt=request.prompt, num_inference_steps=40, denoising_start=0.7).images
-
+
         cleanup(pipeline, request.loras)
-
+
         return images
     except Exception as e:
         cleanup(pipeline, request.loras)
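
Likewise untouched by this commit: the generator entry tests not request.seed is any([None, 0, -1]). Since any([None, 0, -1]) evaluates to True (-1 is truthy), the expression is an identity comparison against True rather than the membership test it reads like. A sketch of the likely intent, assuming the same torch, random, device, and request names as the surrounding file:

    generator = [
        torch.Generator(device=device).manual_seed(request.seed + i)
        if request.seed not in (None, 0, -1)  # honor an explicit, valid seed
        else torch.Generator(device=device).manual_seed(random.randint(0, 2**32 - 1))
        for i in range(request.num_images_per_prompt)
    ]
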