aai / tabs /images /load_models.py
barreloflube's picture
Refactor app.py to add audio tab and update gradio UI
acde4c3
raw
history blame
2.35 kB
import torch
from diffusers import (
AutoPipelineForText2Image,
AutoencoderKL,
FluxControlNetModel,
ControlNetModel,
FluxMultiControlNetModel,
)
from diffusers.schedulers import *
from config import Config
def init_sys():
    """Load the text-to-image pipelines, VAEs, SDXL refiner and ControlNets.

    Every entry of ``Config.IMAGES_MODELS`` is read through its ``'repo_id'``
    and ``'compute_type'`` keys and receives a new ``'pipeline'`` key holding
    the loaded pipeline. Every entry of ``Config.IMAGES_CONTROLNETS`` is read
    through ``'loader'``, ``'repo_id'`` and ``'compute_type'`` and receives a
    ``'controlnet'`` key when its loader is one of ``'flux-multi'``,
    ``'sdxl'`` or ``'flux'``. Both collections are mutated in place.

    Returns:
        The tuple ``(device, models, flux_vae, sdxl_vae, refiner,
        controlnets)`` where ``device`` is ``"cuda"`` or ``"cpu"``, and
        ``models`` / ``controlnets`` are the (now populated) ``Config``
        collections.

    Raises:
        Whatever ``from_pretrained`` / ``.to`` raise when a model cannot be
        loaded; only the first (fp16-variant) attempt per pipeline is retried.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    models = Config.IMAGES_MODELS
    for model in models:
        load_kwargs = {
            "torch_dtype": model['compute_type'],
            "safety_checker": None,
        }
        try:
            # Prefer the fp16 weight variant when the repository publishes one.
            pipeline = AutoPipelineForText2Image.from_pretrained(
                model['repo_id'],
                variant="fp16",
                **load_kwargs,
            ).to(device)
        except Exception:
            # Fall back to the default weights. ``Exception`` (rather than a
            # bare ``except``) lets KeyboardInterrupt / SystemExit abort the
            # load instead of silently starting a second download.
            pipeline = AutoPipelineForText2Image.from_pretrained(
                model['repo_id'],
                **load_kwargs,
            ).to(device)

        model['pipeline'] = pipeline
        # NOTE(review): the pipeline is moved to `device` above and CPU
        # offload is enabled here; confirm that both steps are intended.
        model['pipeline'].enable_model_cpu_offload()

    # VAE n Refiner
    flux_vae = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="vae",
        torch_dtype=torch.bfloat16,
    ).to(device)
    sdxl_vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix",
        torch_dtype=torch.float16,
    ).to(device)
    refiner = AutoPipelineForText2Image.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        vae=sdxl_vae,
        torch_dtype=torch.float16,
    ).to(device)

    # ControlNet
    controlnets = Config.IMAGES_CONTROLNETS
    for controlnet in controlnets:
        loader = controlnet['loader']
        if loader == 'flux-multi':
            # The multi-ControlNet wrapper is built around a single Flux model.
            controlnet['controlnet'] = FluxMultiControlNetModel([
                FluxControlNetModel.from_pretrained(
                    controlnet['repo_id'],
                    torch_dtype=controlnet['compute_type'],
                ).to(device)
            ])
        elif loader == 'sdxl':
            controlnet['controlnet'] = ControlNetModel.from_pretrained(
                controlnet['repo_id'],
                torch_dtype=controlnet['compute_type'],
            ).to(device)
        elif loader == 'flux':
            controlnet['controlnet'] = FluxControlNetModel.from_pretrained(
                controlnet['repo_id'],
                torch_dtype=controlnet['compute_type'],
            ).to(device)
        # Entries with any other loader value are left without a
        # 'controlnet' key, as before.

    return device, models, flux_vae, sdxl_vae, refiner, controlnets
# Run init_sys() at import time and bind its six results as module-level names.
(
    device,
    models,
    flux_vae,
    sdxl_vae,
    refiner,
    controlnets,
) = init_sys()