File size: 2,539 Bytes
6aae8e6
 
 
 
 
c3ac0c7
 
6aae8e6
 
 
 
 
 
 
 
 
 
deff1bd
6aae8e6
 
 
deff1bd
6aae8e6
 
 
 
 
 
 
 
 
 
 
c018f93
6aae8e6
 
 
 
c018f93
6aae8e6
 
25c41f7
6aae8e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1c45fe
6aae8e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd9143f
 
6aae8e6
 
 
dd9143f
 
 
 
6aae8e6
 
 
 
d1c45fe
6aae8e6
 
 
 
 
 
 
 
 
 
 
 
c3ac0c7
6aae8e6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python

from __future__ import annotations

import os
import random

import gradio as gr
import numpy as np
import torch
from diffusers import AutoencoderKL, DiffusionPipeline

# Largest seed value accepted by the seed slider and the random draw:
# the int32 maximum, i.e. 2**31 - 1.
MAX_SEED: int = np.iinfo(np.int32).max
# Upper bound, in pixels, for the width and height sliders.
MAX_IMG_SIZE: int = 4_096

# Target device for inference. The pipeline is moved here explicitly below,
# so this is the single place to change it.
device = torch.device("cpu")

# A replacement VAE ("sdxl-vae-fp16-fix") is loaded separately and handed to
# the pipeline in place of its default VAE. Both are loaded as float32.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float32)
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae,
    torch_dtype=torch.float32,
    use_safetensors=True,
    # Selects the "fp16" weight files in the repo; `torch_dtype` above
    # determines the dtype the weights are loaded into.
    variant="fp16",
)
# Previously `device` was declared but never used; place the pipeline on it
# explicitly so the configured device is actually honoured.
pipe = pipe.to(device)

def random_seed(seed: int, randomize: bool):
    """Resolve the seed to use for the next generation.

    Args:
        seed: The seed currently selected in the UI.
        randomize: When true, ignore *seed* and draw a new one uniformly from
            ``0`` to ``MAX_SEED`` (both ends inclusive).

    Returns:
        *seed* unchanged when *randomize* is false, otherwise the newly drawn seed.
    """
    if not randomize:
        return seed
    return random.randint(0, MAX_SEED)

def generate(
    prompt: str = "a cat",
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 5.0,
    num_inference_steps: int = 10,
):
    """Render a single image for *prompt* with the module-level pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: Seed for the ``torch.Generator`` handed to the pipeline.
        width: Output image width in pixels.
        height: Output image height in pixels.
        guidance_scale: Guidance scale passed to the pipeline (default 5.0).
        num_inference_steps: Number of inference steps (default 10).

    Returns:
        The first image of the pipeline output (``output_type="pil"``).
    """
    # Gradio sliders may deliver numbers as floats, but
    # ``torch.Generator.manual_seed`` only accepts integers; coerce the seed,
    # and the size/step parameters too, to be safe.
    generator = torch.Generator().manual_seed(int(seed))

    output = pipe(
        prompt=prompt,
        negative_prompt=None,
        prompt_2=None,
        negative_prompt_2=None,
        width=int(width),
        height=int(height),
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        output_type="pil",
    )
    return output.images[0]


with gr.Blocks() as instance:
    gr.Markdown("# Stable Diffusion")

    with gr.Group():
        prompt = gr.Textbox(label="Prompt")

        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )

        # When checked, `random_seed` replaces the slider value with a fresh
        # seed before each generation.
        is_random_seed = gr.Checkbox(label="Random seed", value=True)

        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=MAX_IMG_SIZE,
                step=32,
                value=1024,
            )

            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=MAX_IMG_SIZE,
                step=32,
                value=1024,
            )

    result = gr.Image(label="Result", show_label=False)

    submit = gr.Button("Generate Image")

    # Two-step event chain: first resolve the seed and write it back to the
    # slider (so the seed actually used is visible), then generate with the
    # slider's updated value.
    gr.on(
        triggers=[submit.click],
        fn=random_seed,
        inputs=[seed, is_random_seed],
        outputs=seed,
    ).then(
        fn=generate,
        inputs=[prompt, seed, width, height],
        outputs=result,
    )

if __name__ == "__main__":
    instance.launch()