File size: 1,868 Bytes
42ae52a
 
daf9c75
 
 
 
 
 
 
 
 
 
42ae52a
 
daf9c75
 
42ae52a
daf9c75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42ae52a
 
daf9c75
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59

import torch
from diffusers import (
    AutoPipelineForText2Image,
    DiffusionPipeline,
    AutoencoderKL,
    FluxControlNetModel,
    FluxMultiControlNetModel,
)
from diffusers.schedulers import *



def _load_text2image_pipeline(repo_id, vae, torch_dtype):
    """Load a text-to-image pipeline for ``repo_id``, preferring the ``fp16`` weight variant.

    The first attempt requests ``variant="fp16"``. If that fails with an
    OSError or ValueError, loading is retried once with the default weights.
    Any other exception (including KeyboardInterrupt) propagates.

    Args:
        repo_id: Hub repository id or local path of the pipeline.
        vae: Pre-loaded ``AutoencoderKL`` to use instead of the repo's own VAE.
        torch_dtype: dtype the pipeline weights are loaded in.

    Returns:
        The pipeline returned by ``AutoPipelineForText2Image.from_pretrained``.
    """
    # safety_checker=None is kept from the original call; pipelines without a
    # safety-checker component presumably ignore it (diffusers logs a warning).
    common_kwargs = dict(vae=vae, torch_dtype=torch_dtype, safety_checker=None)
    try:
        return AutoPipelineForText2Image.from_pretrained(
            repo_id, variant="fp16", **common_kwargs
        )
    except (OSError, ValueError):
        # NOTE(review): depending on the diffusers version, a missing weight variant
        # surfaces as OSError (file not found on the Hub) or ValueError ("no such
        # variant") — confirm against the pinned diffusers release.
        return AutoPipelineForText2Image.from_pretrained(repo_id, **common_kwargs)


def load_flux():
    """Load the FLUX.1-dev text-to-image pipeline, a standalone VAE and a ControlNet.

    Returns:
        A tuple ``(device, models, flux_vae, controlnet)``:

        - ``device``: ``"cuda"`` if ``torch.cuda.is_available()``, else ``"cpu"``.
        - ``models``: list of model-config dicts (``repo_id``, ``loader``,
          ``compute_type``); each dict gains a ``"pipeline"`` key holding the
          loaded pipeline. On CUDA the pipeline uses model CPU offload; on a
          CPU-only machine it stays on the CPU.
        - ``flux_vae``: an ``AutoencoderKL`` from the FLUX.1-dev ``vae`` subfolder,
          in bfloat16, on ``device``. It is a separate instance from the VAE inside
          the pipeline.
        - ``controlnet``: a ``FluxMultiControlNetModel`` wrapping the Shakker-Labs
          FLUX.1-dev ControlNet Union Pro, in bfloat16, on ``device``.
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    # Models
    models = [
        {
            "repo_id": "black-forest-labs/FLUX.1-dev",
            "loader": "flux",
            "compute_type": torch.bfloat16,
        }
    ]

    for model in models:
        # Load the pipeline's VAE exactly once, outside the variant fallback, so a
        # failed fp16 attempt does not load it a second time. It is left on the
        # CPU; the pipeline placement below decides where it runs.
        pipeline_vae = AutoencoderKL.from_pretrained(
            model["repo_id"], subfolder="vae", torch_dtype=model["compute_type"]
        )
        model["pipeline"] = _load_text2image_pipeline(
            model["repo_id"], pipeline_vae, model["compute_type"]
        )

        if use_cuda:
            # Do NOT call .to("cuda") first: that would put every weight on the GPU
            # at once (OOM risk for FLUX) only for offload to move them back to the
            # CPU. Offload moves each sub-model to the GPU just while it runs.
            model["pipeline"].enable_model_cpu_offload()
        else:
            # enable_model_cpu_offload() targets a CUDA execution device by default,
            # so on a CPU-only machine keep the whole pipeline on the CPU instead.
            model["pipeline"].to(device)

    # Standalone VAE, a separate instance from the one owned by the pipeline.
    flux_vae = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
    ).to(device)

    # ControlNet, wrapped in a multi-ControlNet container.
    controlnet = FluxMultiControlNetModel([
        FluxControlNetModel.from_pretrained(
            "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
            torch_dtype=torch.bfloat16,
        ).to(device)
    ])

    return device, models, flux_vae, controlnet


# Runs at import time: load_flux() loads every model (from_pretrained may download
# weights), and the results are bound as module-level globals.
(device, models, flux_vae, controlnet) = load_flux()