import torch

from diffusers import (
    AutoPipelineForText2Image,
    AutoencoderKL,
    FluxControlNetModel,
    FluxMultiControlNetModel,
)


def load_flux():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # One entry per base model; each dict gains a "pipeline" key once loaded.
    models = [
        {
            "repo_id": "black-forest-labs/FLUX.1-dev",
            "loader": "flux",
            "compute_type": torch.bfloat16,
        }
    ]

    # Load the bfloat16 VAE once and reuse it for every pipeline and the return value.
    flux_vae = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
    ).to(device)

    for model in models:
        try:
            # Prefer an fp16 variant if the repository ships one.
            model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
                model["repo_id"],
                vae=flux_vae,
                torch_dtype=model["compute_type"],
                safety_checker=None,  # FLUX pipelines have no safety checker module
                variant="fp16",
            )
        except Exception:
            # Fall back to the default weights when no fp16 variant is available.
            model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
                model["repo_id"],
                vae=flux_vae,
                torch_dtype=model["compute_type"],
                safety_checker=None,
            )

        if device == "cuda":
            # Keep sub-models on the CPU and move each to the GPU only while it runs.
            model["pipeline"].enable_model_cpu_offload()
        else:
            model["pipeline"].to(device)

    # Union-Pro ControlNet, wrapped so additional control inputs can be stacked later.
    controlnet = FluxMultiControlNetModel(
        [
            FluxControlNetModel.from_pretrained(
                "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
                torch_dtype=torch.bfloat16,
            ).to(device)
        ]
    )

    return device, models, flux_vae, controlnet


device, models, flux_vae, controlnet = load_flux()
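

# Example usage: a minimal sketch of generating an image with the pipeline loaded
# above. The prompt, resolution, step count, guidance scale and seed are
# illustrative values, not part of the original loader.
pipe = models[0]["pipeline"]
image = pipe(
    prompt="a photo of an astronaut riding a horse on the moon",
    height=1024,
    width=1024,
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.Generator(device).manual_seed(0),
).images[0]
image.save("flux_sample.png")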
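

# Sketch: wiring the returned multi-ControlNet and VAE into a FluxControlNetPipeline.
# The conditioning image path and the control_mode index are assumptions for
# illustration; the Union-Pro model card defines the actual mode indices.
from diffusers import FluxControlNetPipeline
from diffusers.utils import load_image

control_pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    controlnet=controlnet,
    vae=flux_vae,
    torch_dtype=torch.bfloat16,
)
if device == "cuda":
    control_pipe.enable_model_cpu_offload()

control_image = load_image("control.png")  # hypothetical conditioning image
controlled = control_pipe(
    prompt="a photo of an astronaut riding a horse on the moon",
    control_image=[control_image],
    control_mode=[0],  # assumed mode index; check the Union-Pro model card
    controlnet_conditioning_scale=[0.6],
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
controlled.save("flux_controlnet_sample.png")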