# ddpm / app.py — Hugging Face Space file by lanzhiwang
# Commit c4f95ed ("debug 4"), 1.94 kB
import gradio as gr
from diffusers import DiffusionPipeline
import torch
from diffusers import DDPMScheduler, UNet2DModel
from PIL import Image
import numpy as np
def erzeuge(prompt):
    """Generate images with the module-level ``pipeline`` for the gallery.

    Args:
        prompt: Text from the Gradio textbox. It is forwarded as ``prompt=``
            only when the pipeline's ``__call__`` declares a ``prompt``
            parameter; otherwise the pipeline is called without arguments.

    Returns:
        The ``images`` attribute of the pipeline output (a list, which is
        what the ``gr.Gallery`` output expects).
    """
    import inspect

    # Passing the prompt positionally is unsafe: for an unconditional
    # pipeline whose first parameter is ``batch_size``, the string would be
    # used as the batch size and image generation would fail.
    if "prompt" in inspect.signature(pipeline.__call__).parameters:
        return pipeline(prompt=prompt).images
    return pipeline().images
def erzeuge_komplex(
    prompt,
    num_inference_steps=50,
    model_id="google/ddpm-cat-256",
    device="cuda",
):
    """Run the DDPM denoising loop by hand and return one PIL image.

    Loads a ``DDPMScheduler`` and ``UNet2DModel`` from ``model_id``, starts
    from Gaussian noise and repeatedly replaces the sample with the
    scheduler's ``prev_sample`` for each timestep.

    Args:
        prompt: Unused. The UNet is called with only the sample and the
            timestep, so no text conditioning is applied.
        num_inference_steps: Number of scheduler timesteps (default 50).
        model_id: Hub id providing both the scheduler and the UNet.
        device: Device the UNet is moved to (default ``"cuda"``).

    Returns:
        A ``PIL.Image.Image`` built from the single generated sample.
    """
    scheduler = DDPMScheduler.from_pretrained(model_id)
    model = UNet2DModel.from_pretrained(model_id).to(device)
    scheduler.set_timesteps(num_inference_steps)

    config = model.config
    shape = (1, config.in_channels, config.sample_size, config.sample_size)
    # Noise is drawn with the default (CPU) generator, then moved to wherever
    # the model lives, so both tensors share a device.
    sample = torch.randn(shape).to(model.device)

    # Inference only: no autograd graph is needed for any step.
    with torch.no_grad():
        for t in scheduler.timesteps:
            noisy_residual = model(sample, t).sample
            sample = scheduler.step(noisy_residual, t, sample).prev_sample

    # Rescale (x / 2 + 0.5) and clamp to [0, 1], then convert the first batch
    # entry from CHW to HWC uint8 for PIL.
    image = (sample / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
    return Image.fromarray((image * 255).round().astype("uint8"))
# Alternative model, kept for reference:
# pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")

# Pipeline used by ``erzeuge``; loaded once when the module is imported.
pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cat-256")
# Fall back to CPU so the app still starts on hosts without a CUDA device
# (generation will be much slower there).
pipeline.to("cuda" if torch.cuda.is_available() else "cpu")
# Gradio UI: a compact row holding a one-line prompt box and a button, with
# a gallery (2 columns x 2 rows) after the row inside the same panel.
# Clicking the button or submitting the textbox both call ``erzeuge``.
# NOTE(review): ``.style()`` is deprecated in later Gradio 3.x releases and
# removed in 4.x; move these options into the constructors when upgrading.
with gr.Blocks() as demo:
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            text = gr.Textbox(
                label="Deine Beschreibung:",
                show_label=False,
                max_lines=1,
                placeholder="Bildbeschreibung",
            ).style(container=False)
            btn = gr.Button("erzeuge Bild").style(
                full_width=False,
                min_width=100,
            )
        gallery = gr.Gallery(
            label="Erzeugtes Bild",
            show_label=False,
            elem_id="gallery",
        ).style(
            columns=[2],
            rows=[2],
            object_fit="contain",
            height="auto",
        )
    # Register the button click first, then Enter in the textbox; each
    # registration gets its own inputs/outputs lists.
    for _trigger in (btn.click, text.submit):
        _trigger(erzeuge, inputs=[text], outputs=[gallery])

if __name__ == "__main__":
    demo.launch()