MaykaGR committed
Commit 02fb232 · verified · parent: e33138b

Update comfy_pulid.py

Files changed (1):
  1. comfy_pulid.py +130 -4
comfy_pulid.py CHANGED
@@ -3,6 +3,17 @@ import random
 import sys
 from typing import Sequence, Mapping, Any, Union
 import torch
+from comfy import model_management
+from huggingface_hub import hf_hub_download
+import spaces
+
+hf_hub_download(repo_id="black-forest-labs/FLUX.1-Redux-dev", filename="flux1-redux-dev.safetensors", local_dir="models/style_models")
+hf_hub_download(repo_id="black-forest-labs/FLUX.1-Depth-dev", filename="flux1-depth-dev.safetensors", local_dir="models/diffusion_models")
+hf_hub_download(repo_id="Comfy-Org/sigclip_vision_384", filename="sigclip_vision_patch14_384.safetensors", local_dir="models/clip_vision")
+hf_hub_download(repo_id="Kijai/DepthAnythingV2-safetensors", filename="depth_anything_v2_vitl_fp32.safetensors", local_dir="models/depthanything")
+hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev", filename="ae.safetensors", local_dir="models/vae/FLUX1")
+hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="clip_l.safetensors", local_dir="models/text_encoders")
+hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="t5xxl_fp16.safetensors", local_dir="models/text_encoders/t5")
 
 
 def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
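
The seven downloads above pin each checkpoint into the local ComfyUI models/ tree; every weight path referenced later in the diff (FLUX1/ae.safetensors, t5/t5xxl_fp16.safetensors, and so on) is relative to that tree. As a sanity check, here is a minimal sketch that only assumes the script runs from the ComfyUI working directory; each expected path is simply the local_dir joined with the filename from the calls above:

    from pathlib import Path

    # Expected files after the hf_hub_download calls above
    # (each entry is local_dir + "/" + filename)
    expected = [
        "models/style_models/flux1-redux-dev.safetensors",
        "models/diffusion_models/flux1-depth-dev.safetensors",
        "models/clip_vision/sigclip_vision_patch14_384.safetensors",
        "models/depthanything/depth_anything_v2_vitl_fp32.safetensors",
        "models/vae/FLUX1/ae.safetensors",
        "models/text_encoders/clip_l.safetensors",
        "models/text_encoders/t5/t5xxl_fp16.safetensors",
    ]
    for rel in expected:
        print(rel, "ok" if Path(rel).is_file() else "MISSING")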
 
@@ -114,12 +125,94 @@ def import_custom_nodes() -> None:
 
 from nodes import NODE_CLASS_MAPPINGS
 
+intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
+dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
+
+# To be added to `model_loaders`, as it loads a model
+dualcliploader_357 = dualcliploader.load_clip(
+    clip_name1="t5/t5xxl_fp16.safetensors",
+    clip_name2="clip_l.safetensors",
+    type="flux",
+)
+cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()
+cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
+loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
+imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()
+getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
+vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
+
+# To be added to `model_loaders`, as it loads a model
+vaeloader_359 = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")
+
+vaeencode = NODE_CLASS_MAPPINGS["VAEEncode"]()
+unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
+
+# To be added to `model_loaders`, as it loads a model
+unetloader_358 = unetloader.load_unet(
+    unet_name="flux1-depth-dev.safetensors", weight_dtype="default"
+)
+ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
+randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
+fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
+depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
+downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS[
+    "DownloadAndLoadDepthAnythingV2Model"
+]()
+
+# To be added to `model_loaders`, as it loads a model
+downloadandloaddepthanythingv2model_437 = (
+    downloadandloaddepthanythingv2model.loadmodel(
+        model="depth_anything_v2_vitl_fp32.safetensors"
+    )
+)
+instructpixtopixconditioning = NODE_CLASS_MAPPINGS[
+    "InstructPixToPixConditioning"
+]()
+text_multiline = NODE_CLASS_MAPPINGS["Text Multiline"]()
+text_multiline_454 = text_multiline.text_multiline(text="FLUX_Redux")
+clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
+
+# To be added to `model_loaders`, as it loads a model
+clipvisionloader_438 = clipvisionloader.load_clip(
+    clip_name="sigclip_vision_patch14_384.safetensors"
+)
+clipvisionencode = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
+stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
+
+# To be added to `model_loaders`, as it loads a model
+stylemodelloader_441 = stylemodelloader.load_style_model(
+    style_model_name="flux1-redux-dev.safetensors"
+)
+emptylatentimage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
+cr_conditioning_input_switch = NODE_CLASS_MAPPINGS[
+    "CR Conditioning Input Switch"
+]()
+cr_model_input_switch = NODE_CLASS_MAPPINGS["CR Model Input Switch"]()
+stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
+basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
+basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
+samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
+vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
+saveimage = NODE_CLASS_MAPPINGS["SaveImage"]()
+imagecrop = NODE_CLASS_MAPPINGS["ImageCrop+"]()
+
+# Add all the loaders above that load a safetensors file
+model_loaders = [dualcliploader_357, vaeloader_359, unetloader_358, clipvisionloader_438, stylemodelloader_441, downloadandloaddepthanythingv2model_437]
+
+# Keep only results that load_models_gpu can handle, unwrapping `.patcher` where present
+valid_models = [
+    getattr(loader[0], 'patcher', loader[0])
+    for loader in model_loaders
+    if not isinstance(loader[0], dict) and not isinstance(getattr(loader[0], 'patcher', None), dict)
+]
+
+# Finally, load the models onto the GPU
+model_management.load_models_gpu(valid_models)
+
 
 def generate_image(prompt, structure_image, style_image, depth_strength, style_strength):
     import_custom_nodes()
     with torch.inference_mode():
-        vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
-        vaeloader_10 = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")
 
         dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
         dualcliploader_11 = dualcliploader.load_clip(
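
The block above moves all checkpoint loading out of generate_image and up to module level, so the weights are resident before any request arrives; note that it runs at import time, which assumes the custom-node classes it indexes (e.g. "CR Clip Input Switch") are already registered in NODE_CLASS_MAPPINGS at that point. The valid_models comprehension copes with the fact that ComfyUI loader nodes return tuples whose first element is sometimes the model patcher itself and sometimes a wrapper exposing it as .patcher, while bare dict results are filtered out as not loadable this way. A minimal sketch with toy stand-ins (the names here are illustrative, not ComfyUI types) showing what the comprehension keeps:

    from types import SimpleNamespace

    # Toy stand-ins for loader-node results: 1-tuples whose first element is
    # a model-like object, a wrapper with `.patcher`, or a bare dict.
    direct = (SimpleNamespace(name="unet"),)
    wrapped = (SimpleNamespace(patcher=SimpleNamespace(name="depth")),)
    as_dict = ({"not": "a model"},)
    model_loaders = [direct, wrapped, as_dict]

    # Same comprehension as in the commit: unwrap `.patcher` when present,
    # drop entries that are dicts either way.
    valid_models = [
        getattr(loader[0], "patcher", loader[0])
        for loader in model_loaders
        if not isinstance(loader[0], dict)
        and not isinstance(getattr(loader[0], "patcher", None), dict)
    ]
    print([m.name for m in valid_models])  # ['unet', 'depth']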
 
@@ -284,5 +377,38 @@ def generate_image(prompt, structure_image, style_image, depth_strength, style_strength):
     return saved_path
 
 
-#if __name__ == "__main__":
-#    main()
+if __name__ == "__main__":
+    with gr.Blocks() as app:
+        # Add a title
+        gr.Markdown("# FLUX Style Shaping")
+
+        with gr.Row():
+            with gr.Column():
+                # Add an input
+                prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
+                # Add a `Row` to place the two groups side by side
+                with gr.Row():
+                    # First group: structure image and depth strength
+                    with gr.Group():
+                        structure_image = gr.Image(label="Structure Image", type="filepath")
+                        depth_strength = gr.Slider(minimum=0, maximum=50, value=15, label="Depth Strength")
+                    # Second group: style image and style strength
+                    with gr.Group():
+                        style_image = gr.Image(label="Style Image", type="filepath")
+                        style_strength = gr.Slider(minimum=0, maximum=1, value=0.5, label="Style Strength")
+
+                # The generate button
+                generate_btn = gr.Button("Generate")
+
+            with gr.Column():
+                # The output image
+                output_image = gr.Image(label="Generated Image")
+
+        # When the button is clicked, trigger the `generate_image` function
+        # with the inputs above and route the result to the output image
+        generate_btn.click(
+            fn=generate_image,
+            inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
+            outputs=[output_image]
+        )
+    app.launch(share=True)
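
The revived __main__ block replaces the commented-out stub with a Gradio UI wired straight into generate_image; it assumes gradio is imported as gr earlier in the file, since the import is not part of this diff. For testing without the UI, the same entry point can be driven headlessly. A hypothetical sketch, assuming a working ComfyUI checkout with this file importable as comfy_pulid and the checkpoints downloaded above; the prompt and image paths are placeholders:

    from comfy_pulid import generate_image

    # Both image arguments are file paths, matching type="filepath" in the UI.
    saved_path = generate_image(
        prompt="a watercolor city skyline at dusk",  # placeholder prompt
        structure_image="inputs/structure.png",      # placeholder local path
        style_image="inputs/style.png",              # placeholder local path
        depth_strength=15,   # UI default; slider range 0-50
        style_strength=0.5,  # UI default; slider range 0-1
    )
    print("saved to:", saved_path)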