import torch
from diffusers import (
    DiffusionPipeline,
    AutoencoderTiny,
    AutoencoderKL,
    AutoPipelineForImage2Image,
)
from huggingface_hub import login
import spaces

# Absolute imports from the flux_app package
from flux_app.config import DTYPE, DEVICE, BASE_MODEL, TAEF1_MODEL
from flux_app.utilities import load_image_from_path, calculateDuration
from flux_app.lora_handling import flux_pipe_call_that_returns_an_iterable_of_images

# Ensure CUDA is available before any models are loaded
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. Please run on a GPU-enabled environment.")

class ModelManager:
    """Loads the FLUX pipelines and exposes text-to-image and image-to-image generation."""

    def __init__(self, hf_token=None):
        self.pipe = None
        self.pipe_i2i = None
        self.good_vae = None
        self.taef1 = None
        
        # Clear CUDA memory cache before loading models
        torch.cuda.empty_cache()

        if hf_token:
            login(token=hf_token)  # Log in with the provided token

        self.initialize_models()
    
    def initialize_models(self):
        """Initializes the diffusion pipelines and autoencoders."""
        # taef1 is a tiny autoencoder used for fast intermediate previews;
        # good_vae is the full-quality VAE reserved for the final decode.
        self.taef1 = AutoencoderTiny.from_pretrained(TAEF1_MODEL, torch_dtype=DTYPE).to(DEVICE)
        self.good_vae = AutoencoderKL.from_pretrained(BASE_MODEL, subfolder="vae", torch_dtype=DTYPE).to(DEVICE)
        self.pipe = DiffusionPipeline.from_pretrained(BASE_MODEL, torch_dtype=DTYPE, vae=self.taef1).to(DEVICE)
        # The image-to-image pipeline reuses the transformer and text encoders
        # from the base pipeline so the weights are only loaded once.
        self.pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
            BASE_MODEL,
            vae=self.good_vae,
            transformer=self.pipe.transformer,
            text_encoder=self.pipe.text_encoder,
            tokenizer=self.pipe.tokenizer,
            text_encoder_2=self.pipe.text_encoder_2,
            tokenizer_2=self.pipe.tokenizer_2,
            torch_dtype=DTYPE,
        ).to(DEVICE)

        # Expose the streaming call directly on the pipeline object
        self.pipe.flux_pipe_call_that_returns_an_iterable_of_images = self.process_images

    def process_images(self, *args, **kwargs):
        """Delegates to the custom FLUX call that yields intermediate images."""
        return flux_pipe_call_that_returns_an_iterable_of_images(self.pipe, *args, **kwargs)
    
    @spaces.GPU(duration=100)
    def generate_image(self, prompt_mash, steps, seed, cfg_scale, width, height, lora_scale):
        """Generates an image using the FLUX pipeline."""
        self.pipe.to(DEVICE)  # Ensure pipeline is on GPU
        generator = torch.Generator(device=DEVICE).manual_seed(seed)
        
        with calculateDuration("Generating image"):
            for img in self.pipe.flux_pipe_call_that_returns_an_iterable_of_images(
                prompt=prompt_mash,
                num_inference_steps=steps,
                guidance_scale=cfg_scale,
                width=width,
                height=height,
                generator=generator,
                joint_attention_kwargs={"scale": lora_scale},
                output_type="pil",
                good_vae=self.good_vae,
            ):
                yield img

    def generate_image_to_image(self, prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed):
        """Generates an image using image-to-image processing."""
        generator = torch.Generator(device=DEVICE).manual_seed(seed)
        self.pipe_i2i.to(DEVICE)
        image_input = load_image_from_path(image_input_path)
        
        final_image = self.pipe_i2i(
            prompt=prompt_mash,
            image=image_input,
            strength=image_strength,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
            joint_attention_kwargs={"scale": lora_scale},
            output_type="pil",
        ).images[0]
        return final_image

# Smoke-test initialization when this module is run directly
if __name__ == "__main__":
    model_manager = ModelManager()
    print("Model Manager initialized successfully with GPU support.")