ChenWu98's picture
Update app.py
5a0d186
raw
history blame
5.89 kB
from diffusers import CycleDiffusionPipeline, DDIMScheduler
import gradio as gr
import torch
from PIL import Image
import utils
import streamlit as st
is_colab = utils.is_google_colab()

# Model loading is switched off. While this flag stays False the block below
# never runs, so no `pipe` object exists and `inference` cannot generate
# anything. Set it to True to build the CycleDiffusion pipeline (this reads
# the USER_TOKEN secret via streamlit's `st.secrets`).
LOAD_PIPELINE = False

if LOAD_PIPELINE:
    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                              num_train_timesteps=1000, clip_sample=False, set_alpha_to_one=False)
    model_id_or_path = "CompVis/stable-diffusion-v1-4"
    pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path,
                                                  use_auth_token=st.secrets["USER_TOKEN"],
                                                  scheduler=scheduler)
    # Only meaningful once `pipe` exists, so the move to GPU lives inside the guard.
    if torch.cuda.is_available():
        pipe = pipe.to("cuda")

# Human-readable hardware label (interpolated into the page header HTML).
device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
def inference(source_prompt, target_prompt, source_guidance_scale=1, guidance_scale=5, num_inference_steps=100,
              width=512, height=512, seed=0, img=None, strength=0.7):
    """Edit ``img`` from ``source_prompt`` towards ``target_prompt`` with the module-level ``pipe``.

    The source image is scaled uniformly (up or down, aspect ratio kept) so it
    fits inside a ``width`` x ``height`` box, then handed to the pipeline.

    Args:
        source_prompt: Text describing the input image.
        target_prompt: Text describing the desired output image.
        source_guidance_scale: Passed to the pipeline as ``source_guidance_scale``.
        guidance_scale: Passed to the pipeline as ``guidance_scale``.
        num_inference_steps: Number of pipeline steps (cast to int, since UI
            sliders may deliver floats).
        width: Width of the bounding box the source image is scaled into.
        height: Height of the bounding box the source image is scaled into.
        seed: Value used to seed torch's global RNG before generation.
        img: Source image (PIL, per the ``type="pil"`` input component).
        strength: Passed to the pipeline as ``strength``.

    Returns:
        The first generated image, or the "nsfw.png" placeholder if the
        pipeline flagged it (see ``replace_nsfw_images``).

    Raises:
        ValueError: If no source image was supplied.
    """
    if img is None:
        raise ValueError("Please upload a source image.")

    torch.manual_seed(int(seed))

    # Uniform scale factor that makes the image fit inside the width x height box.
    ratio = min(height / img.height, width / img.width)
    img = img.resize((int(img.width * ratio), int(img.height * ratio)))

    output = pipe(prompt=target_prompt,
                  source_prompt=source_prompt,
                  init_image=img,
                  num_inference_steps=int(num_inference_steps),
                  eta=0.1,
                  strength=strength,
                  guidance_scale=guidance_scale,
                  source_guidance_scale=source_guidance_scale,
                  )
    # Pass the whole pipeline output, not `.images[0]`: the NSFW filter needs
    # both `.images` and `.nsfw_content_detected`.
    return replace_nsfw_images(output)
def replace_nsfw_images(results):
    """Replace every image flagged by the safety checker with a placeholder.

    ``results.images`` is modified in place: each flagged entry becomes the
    image loaded from "nsfw.png".

    Args:
        results: Pipeline output exposing ``images`` (an indexable, mutable
            sequence) and ``nsfw_content_detected`` (one truthy/falsy flag per
            image). ``nsfw_content_detected`` may be ``None``; presumably that
            happens when no safety checker ran (verify against the diffusers
            version in use). In that case nothing is replaced.

    Returns:
        The first entry of ``results.images`` after replacement.
    """
    flags = results.nsfw_content_detected
    if flags is not None:
        for i, flagged in enumerate(flags):
            if flagged:
                results.images[i] = Image.open("nsfw.png")
    return results.images[0]
# Custom stylesheet handed to gr.Blocks(css=css). Rules target the header
# markup (class "finetuned-diffusion-div"), `.tabs`, and the element with id
# "gallery".
css = """.finetuned-diffusion-div div{display:inline-flex;align-items:center;gap:.8rem;font-size:1.75rem}.finetuned-diffusion-div div h1{font-weight:900;margin-bottom:7px}.finetuned-diffusion-div p{margin-bottom:10px;font-size:94%}.finetuned-diffusion-div p a{text-decoration:underline}.tabs{margin-top:0;margin-bottom:0}#gallery{min-height:20rem}
"""
# Page layout: header HTML, image input/output on the left, generation options
# on the right, a single example row, and a footer.
with gr.Blocks(css=css) as demo:
    # Header. Every paragraph is properly opened and closed; the "Running on"
    # line previously sat outside a <p> and was followed by a stray </p>.
    gr.HTML(
        f"""
<div class="finetuned-diffusion-div">
<div>
<h1>CycleDiffusion with Stable Diffusion</h1>
</div>
<p>
Demo for CycleDiffusion with Stable Diffusion, built with Diffusers 🧨 by HuggingFace 🤗.
</p>
<p>You can skip the queue in the colab: <a href="https://colab.research.google.com/gist/qunash/42112fb104509c24fd3aa6d1c11dd6e0/copy-of-fine-tuned-diffusion-gradio.ipynb"><img data-canonical-src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" src="https://camo.githubusercontent.com/84f0493939e0c4de4e6dbe113251b4bfb5353e57134ffd9fcab6b8714514d4d1/68747470733a2f2f636f6c61622e72657365617263682e676f6f676c652e636f6d2f6173736574732f636f6c61622d62616467652e737667"></a></p>
<p>Running on <b>{device}</b>{(" in a <b>Google Colab</b>." if is_colab else "")}</p>
</div>
"""
    )

    with gr.Row():
        # Left column: trigger button, source image, generated image.
        with gr.Column(scale=55):
            with gr.Group():
                with gr.Row():
                    generate = gr.Button(value="Generate").style(rounded=(False, True, True, False))
                img = gr.Image(label="Source image", height=256, tool="editor", type="pil")
                image_out = gr.Image(height=512)

        # Right column: prompts and sampling options.
        with gr.Column(scale=45):
            with gr.Tab("Options"):
                with gr.Group():
                    source_prompt = gr.Textbox(label="Source prompt", placeholder="Source prompt describes the input image")
                    target_prompt = gr.Textbox(label="Target prompt", placeholder="Target prompt describes the output image")
                    with gr.Row():
                        source_guidance_scale = gr.Slider(label="Source guidance scale", value=1, minimum=1, maximum=10)
                        guidance_scale = gr.Slider(label="Target guidance scale", value=5, minimum=1, maximum=10)
                    with gr.Row():
                        num_inference_steps = gr.Slider(label="Number of inference steps", value=100, minimum=25, maximum=500, step=1)
                        strength = gr.Slider(label="Strength", value=0.7, minimum=0.5, maximum=1, step=0.01)
                    with gr.Row():
                        width = gr.Slider(label="Width", value=512, minimum=64, maximum=1024, step=8)
                        height = gr.Slider(label="Height", value=512, minimum=64, maximum=1024, step=8)
                    seed = gr.Slider(0, 2147483647, label='Seed', value=0, step=1)

    # Order must match inference()'s positional parameters.
    inputs = [source_prompt, target_prompt, source_guidance_scale, guidance_scale, num_inference_steps,
              width, height, seed, img, strength]
    generate.click(inference, inputs=inputs, outputs=image_out)

    # Example rows use the same component order as `inputs`, so reuse it
    # instead of repeating the list.
    ex = gr.Examples(
        examples=[
            ["An astronaut riding a horse", "An astronaut riding an elephant", 1, 2, 100, 512, 512, 0, "images/astronaut_horse.png", 0.8],
        ],
        inputs=inputs,
        outputs=image_out,
        fn=inference,
        cache_examples=False,
    )

    # Footer. The string content is kept at column 0 so Markdown does not
    # treat it as an indented code block.
    gr.Markdown('''
Space by: [![Twitter Follow](https://img.shields.io/twitter/follow/ChenHenryWu?style=social)](https://twitter.com/ChenHenryWu)
![visitors](https://visitor-badge.glitch.me/badge?page_id=ChenWu98.CycleDiffusion)
''')
# Keyword options for demo.launch(): debug mode and a public share link are
# both enabled exactly when running inside Google Colab.
launch_options = dict(debug=is_colab, share=is_colab)

if not is_colab:
    # Outside Colab, route requests through the queue with a single worker.
    demo.queue(concurrency_count=1)

demo.launch(**launch_options)