File size: 5,399 Bytes
353dd90
 
 
 
 
 
 
f2ab5fc
 
353dd90
f2ab5fc
353dd90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2ab5fc
353dd90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
import gradio as gr


# Load the two Stable Cascade stages onto the GPU.
# - prior:   prompt -> image embeddings, loaded in bfloat16
# - decoder: image embeddings -> images, loaded in float16
# Each stage switches to xformers memory-efficient attention right after it is
# loaded (presumably this needs the `xformers` package installed — confirm).
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    torch_dtype=torch.bfloat16,
).to("cuda")
prior.enable_xformers_memory_efficient_attention()

decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    torch_dtype=torch.float16,
).to("cuda")
decoder.enable_xformers_memory_efficient_attention()

def generate_images(
    prompt="a photo of a girl",
    negative_prompt="bad,ugly,deformed",
    height=1024,
    width=1024,
    guidance_scale=4.0,
    prior_inference_steps=20,
    decoder_inference_steps=10,
    num_images_per_prompt=1,
):
    """
    Generate images from a text prompt with the two-stage Stable Cascade pipeline.

    The module-level ``prior`` pipeline turns the prompt into image embeddings.
    The module-level ``decoder`` pipeline turns those embeddings into PIL images.

    Parameters:
    - prompt (str): The prompt to generate images for.
    - negative_prompt (str): The negative prompt to steer generation away from.
      It is passed to both stages. The decoder runs with guidance_scale=0.0,
      so presumably only the prior stage actually uses it.
    - height (int): The height of the generated images, in pixels.
    - width (int): The width of the generated images, in pixels.
    - guidance_scale (float): Guidance scale for the prior stage.
    - prior_inference_steps (int): Number of denoising steps for the prior.
    - decoder_inference_steps (int): Number of denoising steps for the decoder.
    - num_images_per_prompt (int): Number of images to generate for the prompt.
      Defaults to 1, the value that was previously hard-coded.

    Returns:
    - List[PIL.Image]: The generated images.
    """
    # UI widgets (e.g. sliders) may deliver whole numbers as floats, but the
    # pipelines expect integer sizes and step counts. int() leaves ints unchanged.
    height = int(height)
    width = int(width)
    prior_inference_steps = int(prior_inference_steps)
    decoder_inference_steps = int(decoder_inference_steps)
    num_images_per_prompt = int(num_images_per_prompt)

    # Stage 1: text -> image embeddings.
    prior_output = prior(
        prompt=prompt,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        num_inference_steps=prior_inference_steps,
    )

    # Stage 2: embeddings -> images.
    # The prior runs in bfloat16 and the decoder in its own dtype (float16 as
    # loaded above). Cast to whatever the decoder uses rather than hard-coding
    # .half(), so the two stay in sync if the decoder's dtype changes.
    image_embeddings = prior_output.image_embeddings.to(decoder.dtype)
    images = decoder(
        image_embeddings=image_embeddings,
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=0.0,  # guidance is applied in the prior stage
        output_type="pil",
        num_inference_steps=decoder_inference_steps,
    ).images

    return images


def web_demo():
    """
    Build the Gradio text-to-image UI for ``generate_images``.

    The layout has two columns:
    - Left: a prompt textbox, a negative-prompt textbox, sliders for image
      height/width, guidance scale and prior/decoder inference steps, and a
      "Generate Image" button.
    - Right: a gallery that shows the generated images.

    The button's click event calls ``generate_images``. The widget values are
    passed positionally, in the order of that function's parameters.

    NOTE(review): in the visible code the ``gr.Blocks()`` context is not bound
    to a name and nothing is returned, so a caller cannot ``.launch()`` the
    result directly. Presumably this is called inside an enclosing Gradio
    Blocks/Tab context that renders it — confirm. For standalone use, bind it
    (``with gr.Blocks() as demo:``) and return ``demo``.
    """
    with gr.Blocks():
        with gr.Row():
            # Left column: generation inputs and the trigger button.
            with gr.Column():
                text2image_prompt = gr.Textbox(
                    lines=1,
                    placeholder="Prompt",
                    show_label=False,
                )

                # This textbox has no default value. Its value is always passed,
                # so generate_images' "bad,ugly,deformed" default is not used
                # when the function is called from this UI.
                text2image_negative_prompt = gr.Textbox(
                    lines=1,
                    placeholder="Negative Prompt",
                    show_label=False,
                )
                with gr.Row():
                    with gr.Column():
                        # The UI defaults to 512x512, whereas generate_images()
                        # itself defaults to 1024x1024.
                        text2image_height = gr.Slider(
                            minimum=128,
                            maximum=1280,
                            step=32,
                            value=512,
                            label="Image Height",
                        )

                        text2image_width = gr.Slider(
                            minimum=128,
                            maximum=1280,
                            step=32,
                            value=512,
                            label="Image Width",
                        )
                        with gr.Row():
                            with gr.Column():
                                # Becomes generate_images' guidance_scale,
                                # which is applied to the prior stage only.
                                text2image_guidance_scale = gr.Slider(
                                    minimum=0.1,
                                    maximum=15,
                                    step=0.1,
                                    value=4.0,
                                    label="Guidance Scale",
                                )
                                text2image_prior_inference_step = gr.Slider(
                                    minimum=1,
                                    maximum=50,
                                    step=1,
                                    value=20,
                                    label="Prior Inference Step",
                                )

                                text2image_decoder_inference_step = gr.Slider(
                                    minimum=1,
                                    maximum=50,
                                    step=1,
                                    value=10,
                                    label="Decoder Inference Step",
                                )
                text2image_predict = gr.Button(value="Generate Image")

            # Right column: output gallery for the list of PIL images.
            with gr.Column():
                output_image = gr.Gallery(
                    label="Generated images",
                    show_label=False,
                    elem_id="gallery",
                )

            # The inputs map positionally onto generate_images(prompt,
            # negative_prompt, height, width, guidance_scale,
            # prior_inference_steps, decoder_inference_steps).
            text2image_predict.click(
                fn=generate_images,
                inputs=[
                    text2image_prompt,
                    text2image_negative_prompt,
                    text2image_height,
                    text2image_width,
                    text2image_guidance_scale,
                    text2image_prior_inference_step,
                    text2image_decoder_inference_step
                ],
                outputs=output_image,
            )