import logging

import torch
from diffusers import (
    AutoPipelineForText2Image,
    DiffusionPipeline,
    AutoencoderKL,
    FluxControlNetModel,
    FluxMultiControlNetModel,
)
# Star import kept as-is: other code may rely on scheduler names being
# re-exported from this module.
from diffusers.schedulers import *  # noqa: F401,F403

_logger = logging.getLogger(__name__)

_FLUX_REPO_ID = "black-forest-labs/FLUX.1-dev"
_CONTROLNET_REPO_ID = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"


def _load_flux_vae():
    """Load the FLUX.1-dev VAE in bfloat16 (device placement is left to the caller)."""
    return AutoencoderKL.from_pretrained(
        _FLUX_REPO_ID,
        subfolder="vae",
        torch_dtype=torch.bfloat16,
    )


def _load_pipeline(model, device):
    """Build the text-to-image pipeline described by *model* for *device*.

    The ``variant="fp16"`` weights are tried first; if that load fails for
    any reason, the default weights are loaded instead.
    """
    # Load the VAE once and reuse it for the fallback attempt instead of
    # loading it again inside each ``from_pretrained`` call.
    shared_kwargs = {
        "vae": _load_flux_vae(),
        "torch_dtype": model["compute_type"],
        "safety_checker": None,
    }
    try:
        pipeline = AutoPipelineForText2Image.from_pretrained(
            model["repo_id"], variant="fp16", **shared_kwargs
        )
    except Exception as err:
        # Deliberately broad: any failure of the fp16-variant load falls back
        # to the default weights. Unlike a bare ``except:``, this lets
        # KeyboardInterrupt / SystemExit propagate, and the reason is logged.
        _logger.warning(
            "Loading %s with variant='fp16' failed (%s); retrying without a variant.",
            model["repo_id"],
            err,
        )
        pipeline = AutoPipelineForText2Image.from_pretrained(
            model["repo_id"], **shared_kwargs
        )

    if device == "cuda":
        # Model CPU offload handles device placement itself, so the pipeline
        # is not moved to the GPU first (that would only raise peak VRAM).
        pipeline.enable_model_cpu_offload()
    else:
        # No accelerator to offload from; just keep the pipeline on `device`.
        pipeline.to(device)
    return pipeline


def load_flux():
    """Load the FLUX.1-dev pipeline(s), a standalone VAE and the ControlNet.

    Returns:
        A ``(device, models, flux_vae, controlnet)`` tuple where

        * ``device`` is ``"cuda"`` when CUDA is available, otherwise ``"cpu"``;
        * ``models`` is a list of dicts with the keys ``"repo_id"``,
          ``"loader"``, ``"compute_type"`` and ``"pipeline"`` (the loaded
          text-to-image pipeline);
        * ``flux_vae`` is a separate FLUX.1-dev VAE instance moved to
          ``device``;
        * ``controlnet`` is a ``FluxMultiControlNetModel`` wrapping the
          Union-Pro ControlNet, moved to ``device``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Models
    models = [
        {
            "repo_id": _FLUX_REPO_ID,
            "loader": "flux",
            "compute_type": torch.bfloat16,
        },
    ]
    for model in models:
        model["pipeline"] = _load_pipeline(model, device)

    # VAE n Refiner: a standalone instance, independent of the pipeline's VAE
    # so that pipeline offloading does not affect it.
    flux_vae = _load_flux_vae().to(device)

    # ControlNet
    union_controlnet = FluxControlNetModel.from_pretrained(
        _CONTROLNET_REPO_ID,
        torch_dtype=torch.bfloat16,
    ).to(device)
    controlnet = FluxMultiControlNetModel([union_controlnet])

    return device, models, flux_vae, controlnet


# Loaded eagerly at import time so importers can use these module-level names.
device, models, flux_vae, controlnet = load_flux()