Refactor model loading in load_models.py to use ControlNetModel instead of FluxControlNetModel
b51595d
import torch | |
from diffusers import ( | |
AutoPipelineForText2Image, | |
AutoencoderKL, | |
FluxControlNetModel, | |
ControlNetModel, | |
FluxMultiControlNetModel, | |
) | |
from diffusers.schedulers import * | |
from config import Config | |
def _load_text2image_pipeline(model, device):
    """Load ``model['repo_id']`` as a text-to-image pipeline into ``model['pipeline']``.

    The ``fp16`` weight variant is tried first; if that load raises, the
    repository's default weights are loaded instead (the failure is logged).

    Args:
        model: Config entry; reads ``'repo_id'`` and ``'compute_type'`` (passed
            as ``torch_dtype``) and writes ``'pipeline'``.
        device: ``"cuda"`` or ``"cpu"``, as chosen by ``init_sys``.
    """
    import logging

    load_kwargs = {"torch_dtype": model['compute_type'], "safety_checker": None}
    try:
        pipeline = AutoPipelineForText2Image.from_pretrained(
            model['repo_id'], variant="fp16", **load_kwargs
        )
    except Exception as err:  # not bare: let KeyboardInterrupt/SystemExit propagate
        logging.getLogger(__name__).warning(
            "Loading fp16 variant of %s failed (%s); falling back to default weights",
            model['repo_id'],
            err,
        )
        pipeline = AutoPipelineForText2Image.from_pretrained(model['repo_id'], **load_kwargs)

    if device == "cuda":
        # enable_model_cpu_offload() manages device placement itself (it parks the
        # sub-models on the CPU and moves each to the GPU only while it runs), so an
        # explicit .to("cuda") beforehand would only create a full-GPU memory peak.
        pipeline.enable_model_cpu_offload()
    else:
        # Offload hooks need an accelerator; on a CPU-only host keep a plain pipeline.
        pipeline.to(device)
    model['pipeline'] = pipeline


def _load_controlnet(controlnet, device):
    """Load the ControlNet described by ``controlnet`` into ``controlnet['controlnet']``.

    ``controlnet['loader']`` selects the class:
      * ``'sdxl'``       -> ``ControlNetModel``
      * ``'flux'``       -> ``FluxControlNetModel``
      * ``'flux-multi'`` -> a ``FluxControlNetModel`` wrapped in a one-element
        ``FluxMultiControlNetModel``
    Any other loader value is logged and the entry is left without a
    ``'controlnet'`` key.

    Args:
        controlnet: Config entry; reads ``'loader'``, ``'repo_id'`` and
            ``'compute_type'`` (passed as ``torch_dtype``).
        device: Device the loaded weights are moved to.
    """
    import logging

    loader = controlnet['loader']
    if loader == 'sdxl':
        model_cls = ControlNetModel
    elif loader in ('flux', 'flux-multi'):
        model_cls = FluxControlNetModel
    else:
        logging.getLogger(__name__).warning(
            "Unknown ControlNet loader %r; entry left unloaded", loader
        )
        return

    net = model_cls.from_pretrained(
        controlnet['repo_id'],
        torch_dtype=controlnet['compute_type'],
    ).to(device)
    if loader == 'flux-multi':
        net = FluxMultiControlNetModel([net])
    controlnet['controlnet'] = net


def init_sys():
    """Load every image pipeline, VAE, refiner and ControlNet used by the app.

    Mutates the config entries in place: each entry of ``Config.IMAGES_MODELS``
    gains a ``'pipeline'`` key, and each entry of ``Config.IMAGES_CONTROLNETS``
    with a recognised ``'loader'`` gains a ``'controlnet'`` key.

    Returns:
        tuple: ``(device, models, flux_vae, sdxl_vae, refiner, controlnets)`` where
        ``device`` is ``"cuda"`` when available else ``"cpu"``, ``models`` and
        ``controlnets`` are the (mutated) config lists, ``flux_vae``/``sdxl_vae``
        are ``AutoencoderKL`` instances and ``refiner`` is the SDXL refiner pipeline.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    models = Config.IMAGES_MODELS
    for model in models:
        _load_text2image_pipeline(model, device)

    # VAE n Refiner
    flux_vae = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
    ).to(device)
    sdxl_vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    ).to(device)
    # NOTE(review): the SDXL refiner is presumably an image-to-image checkpoint;
    # AutoPipelineForText2Image builds a text-to-image pipeline that does not accept
    # an `image=` input. Confirm how callers use `refiner` (AutoPipelineForImage2Image
    # may be what is intended).
    refiner = AutoPipelineForText2Image.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        vae=sdxl_vae,
        torch_dtype=torch.float16,
    ).to(device)

    # ControlNet
    controlnets = Config.IMAGES_CONTROLNETS
    for controlnet in controlnets:
        _load_controlnet(controlnet, device)

    return device, models, flux_vae, sdxl_vae, refiner, controlnets
# Populate the module-level handles once, at import time, from init_sys().
(
    device,
    models,
    flux_vae,
    sdxl_vae,
    refiner,
    controlnets,
) = init_sys()