# Page residue from the hosting site ("Spaces: Running"); not part of the program.
from PIL import Image | |
import numpy as np | |
import gradio as gr | |
import spaces | |
import torch | |
from tqdm import tqdm | |
from controlnet import QRControlNet | |
from game_of_life import GameOfLife | |
from utils import resize_image, generate_image_from_grid | |
def generate_all_images(
    gol_grids: list[np.ndarray],
    source_image: Image.Image,
    num_inference_steps: int,
    controlnet_conditioning_scale: float,
    strength: float,
    prompt: str,
    negative_prompt: str,
    seed: int,
    guidance_scale: float,
    img_size: int,
) -> list[Image.Image]:
    """Render one ControlNet image per Game of Life grid.

    Every grid is inverted (``1 - grid``), converted to a control image with
    ``generate_image_from_grid`` and handed, together with the resized source
    image, to ``QRControlNet.generate_image``.

    Args:
        gol_grids: Grids to render, in playback order.
        source_image: Reference image; resized to ``img_size`` before use.
        num_inference_steps: Forwarded to ``generate_image`` for every frame.
        controlnet_conditioning_scale: Forwarded after conversion to ``float``.
        strength: Forwarded to ``generate_image``.
        prompt: Forwarded to ``generate_image``.
        negative_prompt: Forwarded to ``generate_image``.
        seed: Forwarded to ``generate_image``; the same seed is used for
            every frame.
        guidance_scale: Forwarded to ``generate_image``.
        img_size: Resolution used for the source image, the control images
            and the ``generate_image`` call.

    Returns:
        The generated images, one per grid, in the same order as
        ``gol_grids``. An empty input yields an empty list.
    """
    # len() rather than truthiness so this also works if the history is
    # handed over as a NumPy array. Nothing to render means there is no
    # reason to pay for loading the model.
    if len(gol_grids) == 0:
        return []

    # Hard-coded to CUDA; the original kept "mps" and "cpu" as commented-out
    # alternatives for running elsewhere.
    device = "cuda"
    print(f"Using {device=}")

    # Initialize the controlnet (this can take a while the first time it's run)
    controlnet = QRControlNet(device=device)

    controlnet_conditioning_scale = float(controlnet_conditioning_scale)
    source_image = resize_image(source_image, resolution=img_size)

    images = []
    for grid in tqdm(gol_grids):
        # The grid is inverted before it is used as the control image.
        control_image = generate_image_from_grid(1 - grid, img_size=img_size)
        images.append(
            controlnet.generate_image(
                source_image=source_image,
                control_image=control_image,
                num_inference_steps=num_inference_steps,
                controlnet_conditioning_scale=controlnet_conditioning_scale,
                strength=strength,
                prompt=prompt,
                negative_prompt=negative_prompt,
                seed=seed,
                guidance_scale=guidance_scale,
                img_size=img_size,
            )
        )
    return images
def make_gif(images: list[Image.Image], gif_path):
    """Write ``images`` to ``gif_path`` as an endlessly looping animated GIF.

    Args:
        images: Frames in playback order; must contain at least one frame.
        gif_path: Destination handed straight to the first frame's ``save``.

    Returns:
        ``gif_path``, unchanged, so the call can be used inline.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("make_gif() needs at least one frame")

    first_frame, later_frames = images[0], images[1:]
    first_frame.save(
        gif_path,
        save_all=True,
        append_images=later_frames,
        duration=200,  # Duration between frames in milliseconds
        loop=0,  # Loop forever
    )
    return gif_path
def generate(
    source_image,
    prompt,
    negative_prompt,
    seed,
    num_inference_steps,
    num_gol_steps,
    gol_grid_dim,
    img_size,
    controlnet_conditioning_scale,
    strength,
    guidance_scale,
):
    """Gradio callback: simulate Game of Life and render it two ways.

    A random square grid is evolved for ``num_gol_steps`` steps. The history
    is saved once as a plain GIF (``gol_original.gif``) and once as a GIF of
    ControlNet images (``gol_controlnet.gif``).

    Args:
        source_image: Reference image passed to ``generate_all_images``.
        prompt: Prompt passed to ``generate_all_images``.
        negative_prompt: Negative prompt passed to ``generate_all_images``.
        seed: Seeds both the random initial grid and the image generation.
        num_inference_steps: Inference steps per generated frame.
        num_gol_steps: Number of Game of Life steps to simulate.
        gol_grid_dim: Side length of the square Game of Life grid.
        img_size: Size in pixels used for both sets of frames.
        controlnet_conditioning_scale: Passed to ``generate_all_images``.
        strength: Passed to ``generate_all_images``.
        guidance_scale: Passed to ``generate_all_images``.

    Returns:
        ``(path_gol_controlnet, path_gol_gif)``: the ControlNet GIF path
        first, then the original Game of Life GIF path.

    Raises:
        gr.Error: If no source image was supplied.
    """
    if source_image is None:
        # Surface a readable message in the UI instead of failing later
        # inside the image-resizing helper.
        raise gr.Error("Please provide a source image.")

    # Numeric widgets may hand over floats (e.g. 10.0); these values are
    # used as counts, sizes and seeds, so normalise them to int once here.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    num_gol_steps = int(num_gol_steps)
    gol_grid_dim = int(gol_grid_dim)
    img_size = int(img_size)

    # Compute the Game of Life first
    gol = GameOfLife()
    gol.set_random_state(dim=(gol_grid_dim, gol_grid_dim), p=0.5, seed=seed)
    gol.generate_n_steps(n=num_gol_steps)
    gol_grids = gol.game_history

    # Generate the gif for the original Game of Life
    gol_images = [
        generate_image_from_grid(grid, img_size=img_size) for grid in gol_grids
    ]
    path_gol_gif = make_gif(gol_images, "gol_original.gif")

    # Generate the gif for the ControlNet Game of Life
    controlnet_images = generate_all_images(
        gol_grids=gol_grids,
        source_image=source_image,
        num_inference_steps=num_inference_steps,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        strength=strength,
        prompt=prompt,
        negative_prompt=negative_prompt,
        seed=seed,
        guidance_scale=guidance_scale,
        img_size=img_size,
    )
    path_gol_controlnet = make_gif(controlnet_images, "gol_controlnet.gif")

    # Order matches the `outputs` list of the interface: ControlNet first.
    return path_gol_controlnet, path_gol_gif
# --- Input and output widgets -------------------------------------------------
# Each input below maps, by position, onto one parameter of `generate`.
source_image = gr.Image(label="Source Image", type="pil", value="sky-gol-image.jpeg")
output_controlnet = gr.Image(label="ControlNet Game of Life")
output_gol = gr.Image(label="Original Game of Life")

prompt = gr.Textbox(
    label="Prompt", value="clear sky with clouds, high quality, background 4k"
)
negative_prompt = gr.Textbox(
    label="Negative Prompt",
    value="ugly, disfigured, low quality, blurry, nsfw, qr code",
)

# precision=0 makes these Number fields round to, and return, integers; the
# values are used as a seed, a step count, a grid dimension and a pixel size.
seed = gr.Number(label="Seed", value=42, precision=0)
num_inference_steps = gr.Number(
    label="Controlnet Inference Steps per frame", value=30, precision=0
)
num_gol_steps = gr.Slider(
    label="Number of Game of Life Steps to Generate",
    minimum=2,
    maximum=100,
    step=1,
    value=6,
)
gol_grid_dim = gr.Number(
    label="Game of Life Grid Dimension",
    value=10,
    precision=0,
)
img_size = gr.Number(label="Image Size (pixels)", value=512, precision=0)

controlnet_conditioning_scale = gr.Slider(
    label="Controlnet Conditioning Scale", minimum=0.1, maximum=10.0, value=2.0
)
strength = gr.Slider(label="Strength", minimum=0.1, maximum=1.0, value=0.9)
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=100, value=20)

# --- Interface ----------------------------------------------------------------
demo = gr.Interface(
    fn=generate,
    # Order must match the positional parameters of `generate`.
    inputs=[
        source_image,
        prompt,
        negative_prompt,
        seed,
        num_inference_steps,
        num_gol_steps,
        gol_grid_dim,
        img_size,
        controlnet_conditioning_scale,
        strength,
        guidance_scale,
    ],
    # Order must match the tuple returned by `generate`.
    outputs=[output_controlnet, output_gol],
    title="ControlNet Game of Life",
    description="""Generate a Game of Life grid and then use ControlNet to enhance the image based on the grid, a reference image and a prompt.
For more information, check out this [blog post](https://www.jerpint.io/blog/diffusion-gol/). Generating frames can be slow and eat up GPU usage, for longer runtimes, you can checkout the [colab](https://colab.research.google.com/github/jerpint/jerpint.github.io/blob/master/colabs/gol_diffusion.ipynb) implementation.
""",
)

# Launched at module level, as in the original, so running the file starts the app.
demo.launch(debug=True)