|
import logging

import torch
from diffusers import (
    AutoencoderKL,
    AutoPipelineForImage2Image,
    AutoPipelineForText2Image,
    ControlNetModel,
    FluxControlNetModel,
    FluxMultiControlNetModel,
)
from diffusers.schedulers import *  # noqa: F403 -- kept as-is; other code may rely on these names

from config import Config
|
|
|
|
|
def _load_text2image_pipeline(model: dict, device: str):
    """Load the text-to-image pipeline described by one ``Config.IMAGES_MODELS`` entry.

    The ``"fp16"`` weight variant is tried first. If that load raises, the
    pipeline is loaded again without a variant (presumably for repos that do
    not publish fp16 weights). The safety checker is disabled in both cases.

    Args:
        model: Entry providing ``'repo_id'`` and ``'compute_type'`` (the latter
            is passed through as ``torch_dtype``).
        device: ``"cuda"`` or ``"cpu"``, as chosen by ``init_sys``.

    Returns:
        The loaded pipeline. On CUDA it has model CPU offload enabled;
        otherwise it has been moved to ``device``.
    """
    log = logging.getLogger(__name__)
    try:
        # Only the load is guarded, so a device/OOM error is not mistaken for
        # a missing variant.
        pipeline = AutoPipelineForText2Image.from_pretrained(
            model['repo_id'],
            torch_dtype=model['compute_type'],
            safety_checker=None,
            variant="fp16",
        )
    except Exception as err:  # not bare: KeyboardInterrupt/SystemExit propagate
        log.info(
            "fp16 load failed for %s (%s); retrying without variant",
            model['repo_id'],
            err,
        )
        pipeline = AutoPipelineForText2Image.from_pretrained(
            model['repo_id'],
            torch_dtype=model['compute_type'],
            safety_checker=None,
        )

    if device == "cuda":
        # enable_model_cpu_offload manages device placement itself; moving the
        # whole pipeline to the GPU beforehand only adds a transient full copy.
        pipeline.enable_model_cpu_offload()
    else:
        # Model CPU offload requires a GPU; on a CPU-only host just place the
        # pipeline on the CPU instead of failing.
        pipeline = pipeline.to(device)
    return pipeline


def _load_controlnet(controlnet: dict, device: str):
    """Load the ControlNet described by one ``Config.IMAGES_CONTROLNETS`` entry.

    ``controlnet['loader']`` selects the class:
    ``'sdxl'`` -> ``ControlNetModel``, ``'flux'`` -> ``FluxControlNetModel``,
    ``'flux-multi'`` -> one ``FluxControlNetModel`` wrapped in
    ``FluxMultiControlNetModel``.

    Returns:
        The model (its weights moved to ``device``), or ``None`` when the
        loader name is not recognised.
    """
    loader = controlnet['loader']
    if loader == 'sdxl':
        model_cls = ControlNetModel
    elif loader in ('flux', 'flux-multi'):
        model_cls = FluxControlNetModel
    else:
        return None

    net = model_cls.from_pretrained(
        controlnet['repo_id'],
        torch_dtype=controlnet['compute_type'],
    ).to(device)
    if loader == 'flux-multi':
        net = FluxMultiControlNetModel([net])
    return net


def init_sys():
    """Load every image-generation model used by the application.

    The entries of ``Config.IMAGES_MODELS`` and ``Config.IMAGES_CONTROLNETS``
    are mutated in place. Each model entry gains a ``'pipeline'`` key. Each
    ControlNet entry with a recognised ``'loader'`` gains a ``'controlnet'``
    key; entries with an unknown loader are skipped with a warning.

    Returns:
        ``(device, models, flux_vae, sdxl_vae, refiner, controlnets)``:

        - ``device`` is ``"cuda"`` when available, else ``"cpu"``.
        - ``models`` and ``controlnets`` are the mutated Config lists.
        - ``flux_vae`` and ``sdxl_vae`` are ``AutoencoderKL`` instances on
          ``device``.
        - ``refiner`` is the SDXL refiner image-to-image pipeline on
          ``device``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    models = Config.IMAGES_MODELS
    for model in models:
        model['pipeline'] = _load_text2image_pipeline(model, device)

    flux_vae = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
    ).to(device)
    sdxl_vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    ).to(device)
    # The SDXL refiner checkpoint is an image-to-image model; loading it via
    # AutoPipelineForText2Image produced a text-to-image pipeline that cannot
    # drive the refiner UNet.
    refiner = AutoPipelineForImage2Image.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        vae=sdxl_vae,
        torch_dtype=torch.float16,
    ).to(device)

    controlnets = Config.IMAGES_CONTROLNETS
    for controlnet in controlnets:
        loaded = _load_controlnet(controlnet, device)
        if loaded is None:
            # Previously skipped silently, surfacing later as a KeyError.
            logging.getLogger(__name__).warning(
                "Skipping ControlNet %r: unknown loader %r",
                controlnet.get('repo_id'),
                controlnet['loader'],
            )
            continue
        controlnet['controlnet'] = loaded

    return device, models, flux_vae, sdxl_vae, refiner, controlnets
|
|
|
# Load all models once, at import time. Other modules presumably import these
# names directly from here — verify before moving this behind a main guard.
(
    device,
    models,
    flux_vae,
    sdxl_vae,
    refiner,
    controlnets,
) = init_sys()
|
|