text2image / app.py
iohanngrig's picture
Update app.py
1b5361f verified
raw
history blame
2.91 kB
import streamlit as st
from io import BytesIO
from typing import Literal
from diffusers import StableDiffusionPipeline
import torch
import time
# Fixed seed for the random generator that text2image hands to the pipeline.
seed = 42
# torch.manual_seed seeds the default RNG and returns that generator object.
generator = torch.manual_seed(seed)
# How many times text2image invokes the pipeline per call; each run overwrites
# the previous result, so only the last run's images are kept.
NUM_ITERS_TO_RUN = 1
# Forwarded to the pipeline as num_inference_steps.
NUM_INFERENCE_STEPS = 5
# Forwarded to the pipeline as num_images_per_prompt; text2image returns only
# the first image of the batch.
NUM_IMAGES_PER_PROMPT = 1
def text2image(
    prompt: str,
    repo_id: Literal[
        "dreamlike-art/dreamlike-photoreal-2.0",
        "hakurei/waifu-diffusion",
        "prompthero/openjourney",
        "stabilityai/stable-diffusion-2-1",
        "runwayml/stable-diffusion-v1-5",
        "nota-ai/bk-sdm-small",
        "CompVis/stable-diffusion-v1-4",
    ],
):
    """Generate an image for ``prompt`` using the checkpoint named by ``repo_id``.

    The pipeline is loaded inside this call. When CUDA is available it is
    loaded in float16 and moved to the GPU; otherwise it is loaded in float32
    with no explicit device move.

    Args:
        prompt: Text description passed to the pipeline.
        repo_id: Repository id passed to
            ``StableDiffusionPipeline.from_pretrained``.

    Returns:
        A tuple ``(image, start, end)``: the first image of the last pipeline
        run, and the ``time.time()`` readings taken before the pipeline is
        loaded and after generation finishes.
    """
    started_at = time.time()

    on_gpu = torch.cuda.is_available()
    print("Using GPU" if on_gpu else "Using CPU")

    # One load call for both cases; only the dtype and the device move differ.
    pipeline = StableDiffusionPipeline.from_pretrained(
        repo_id,
        torch_dtype=torch.float16 if on_gpu else torch.float32,
        use_safetensors=True,
    )
    if on_gpu:
        pipeline = pipeline.to("cuda")

    # Every pass overwrites ``images``; only the final pass is returned.
    for _run in range(NUM_ITERS_TO_RUN):
        output = pipeline(
            prompt,
            num_inference_steps=NUM_INFERENCE_STEPS,
            generator=generator,
            num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
        )
        images = output.images

    finished_at = time.time()
    first_image = images[0]
    return first_image, started_at, finished_at
def app():
    """Render the Streamlit page and run text-to-image generation on submit.

    The page shows a prompt text area, a model selector and a one-button form.
    When the form is submitted with a non-blank prompt, ``text2image`` is
    called, and the elapsed time, the generated image and a PNG download
    button are displayed. A blank prompt produces a warning instead of
    starting the (potentially very long) generation.
    """
    st.header("Text-to-image Web App")
    st.subheader("Powered by Hugging Face")
    user_input = st.text_area(
        "Enter your text prompt below and click the button to submit."
    )
    option = st.selectbox(
        "Select model (in order of processing time)",
        (
            "nota-ai/bk-sdm-small",
            "CompVis/stable-diffusion-v1-4",
            "runwayml/stable-diffusion-v1-5",
            "prompthero/openjourney",
            "hakurei/waifu-diffusion",
            "stabilityai/stable-diffusion-2-1",
            "dreamlike-art/dreamlike-photoreal-2.0",
        ),
    )
    with st.form("my_form"):
        submit = st.form_submit_button(label="Submit text prompt")

    if not submit:
        return

    # Do not load a model and run a long generation for an empty prompt.
    if not user_input or not user_input.strip():
        st.warning("Please enter a text prompt before submitting.")
        return

    with st.spinner(text="Generating image ... It may take up to 20 minutes."):
        im, start, end = text2image(prompt=user_input, repo_id=option)
        # Serialize the image to PNG bytes in memory for the download button.
        buf = BytesIO()
        im.save(buf, format="PNG")
        byte_im = buf.getvalue()

    hours, rem = divmod(end - start, 3600)
    minutes, seconds = divmod(rem, 60)
    st.success(
        f"Processing time: {int(hours):0>2}:{int(minutes):0>2}:{seconds:05.2f}."
    )

    # Results are rendered outside the form: Streamlit does not allow
    # st.download_button inside st.form.
    st.image(im)
    st.download_button(
        label="Click here to download",
        data=byte_im,
        file_name="generated_image.png",
        mime="image/png",
    )
# Build the page only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    app()