File size: 1,944 Bytes
e6a0177
2e1269a
 
 
 
 
 
 
 
c4f95ed
 
2e1269a
c4f95ed
2e1269a
 
 
c4f95ed
2e1269a
 
 
c4f95ed
2e1269a
 
 
 
 
c4f95ed
2e1269a
 
 
 
 
 
c4f95ed
 
2e1269a
c4f95ed
2e1269a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4f95ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import gradio as gr
from diffusers import DiffusionPipeline
import torch
from diffusers import DDPMScheduler, UNet2DModel
from PIL import Image
import numpy as np


def erzeuge(prompt):
    """Generate images with the module-level ``pipeline`` for the Gradio gallery.

    Args:
        prompt: Text from the textbox. It is forwarded only if the loaded
            pipeline's ``__call__`` declares a ``prompt`` parameter.

    Returns:
        The ``images`` list of the pipeline output, which the gallery displays.
    """
    import inspect

    accepts_prompt = "prompt" in inspect.signature(pipeline.__call__).parameters
    if accepts_prompt:
        # Text-conditioned pipelines take the description by keyword.
        result = pipeline(prompt=prompt)
    else:
        # Unconditional pipelines have no prompt parameter; passing the text
        # positionally would bind it to their first parameter (e.g. batch_size).
        result = pipeline()
    return result.images


_DDPM_CAT_REPO = "google/ddpm-cat-256"

# UNet instances keyed by device, so repeated calls do not reload the weights.
_unet_cache = {}


def _lade_unet(device):
    """Return the DDPM cat UNet on ``device``, loading it on first use."""
    model = _unet_cache.get(device)
    if model is None:
        model = UNet2DModel.from_pretrained(_DDPM_CAT_REPO).to(device)
        _unet_cache[device] = model
    return model


def erzeuge_komplex(prompt, *, device="cuda", num_inference_steps=50):
    """Generate one image with a hand-written DDPM denoising loop.

    Args:
        prompt: Unused. The UNet below is called without any text input; the
            parameter is kept for interface compatibility.
        device: Torch device that holds the UNet and the sample.
        num_inference_steps: Number of scheduler timesteps to run.

    Returns:
        A ``PIL.Image.Image`` built from the final denoised sample.
    """
    # set_timesteps mutates the scheduler, so build a fresh one per call
    # instead of sharing a cached instance between requests.
    scheduler = DDPMScheduler.from_pretrained(_DDPM_CAT_REPO)
    scheduler.set_timesteps(num_inference_steps)
    model = _lade_unet(device)

    sample_size = model.config.sample_size
    channels = model.config.in_channels
    # Noise is drawn on the CPU generator and then moved, as before.
    sample = torch.randn((1, channels, sample_size, sample_size)).to(device)

    with torch.no_grad():
        for t in scheduler.timesteps:
            noise_pred = model(sample, t).sample
            sample = scheduler.step(noise_pred, t, sample).prev_sample

    # Rescale (presumably from [-1, 1]) to [0, 1], then to an HWC uint8 array.
    image = (sample / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
    return Image.fromarray((image * 255).round().astype("uint8"))


# Alternative, text-conditioned model:
# pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
# The DDPM cat model generates images without using a text prompt.
pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cat-256")

# Use the GPU when present; otherwise fall back to the CPU instead of failing
# at import time on machines without CUDA.
pipeline.to("cuda" if torch.cuda.is_available() else "cpu")


with gr.Blocks() as demo:
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            # Layout options are passed to the constructors: ``.style()`` is
            # deprecated in Gradio 3.x and removed in Gradio 4.
            text = gr.Textbox(
                label="Deine Beschreibung:",
                show_label=False,
                max_lines=1,
                placeholder="Bildbeschreibung",
                container=False,
            )
            # scale=0 keeps the button at its own width (was full_width=False).
            btn = gr.Button("erzeuge Bild", scale=0, min_width=100)

        gallery = gr.Gallery(
            label="Erzeugtes Bild",
            show_label=False,
            elem_id="gallery",
            columns=[2],
            rows=[2],
            object_fit="contain",
            height="auto",
        )

    # Clicking the button or pressing Enter in the textbox runs generation.
    btn.click(erzeuge, inputs=[text], outputs=[gallery])
    text.submit(erzeuge, inputs=[text], outputs=[gallery])

if __name__ == "__main__":
    demo.launch()