# Spaces: Paused
import gradio | |
import torch | |
import numpy | |
from PIL import Image | |
from torchvision import transforms | |
from diffusers import StableDiffusionInpaintPipeline | |
from diffusers import DPMSolverMultistepScheduler | |
# Pick the compute device once; the rest of the script reads `deviceStr` and `device`.
deviceStr = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(deviceStr)

# Load the inpainting pipeline and the initial latent noise.
# - On CUDA: half-precision weights, xformers attention, fp16 latents.
# - On CPU: default-precision weights and latents.
# The safety checker is disabled with `safety_checker=None`, the form diffusers
# supports. A stub returning `(images, False)` still runs the CLIP
# feature-extractor preprocessing on every frame. Newer diffusers releases also
# iterate the second return value per image, so a bare `False` raises TypeError.
# NOTE(review): the latent shape (1, 4, 64, 64) assumes 512x512 frames
# (64 = 512 / 8) — keep in sync with the input image size.
_inpaintModelId = "runwayml/stable-diffusion-inpainting"
if deviceStr == "cuda":
    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        _inpaintModelId,
        revision="fp16",
        torch_dtype=torch.float16,
        safety_checker=None,
    )
    pipeline.to(device)
    pipeline.enable_xformers_memory_efficient_attention()
    latents = torch.randn((1, 4, 64, 64), device=device, dtype=torch.float16)
else:
    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        _inpaintModelId,
        safety_checker=None,
    )
    latents = torch.randn((1, 4, 64, 64), device=device)
# Mutable state shared with `diffuse` across Gradio calls.
imageSize = (512, 512)  # (width, height) of the placeholder frame
lastSeed = 512  # seed the generator below is initialised with
# Placeholder frame, returned by `diffuse` until a first image is produced
# (Image.new fills with black by default).
lastImage = Image.new("RGB", imageSize)
generator = torch.Generator(device).manual_seed(lastSeed)
def diffuse(staticLatents, inputImage, mask, pauseInference, prompt, negativePrompt, guidanceScale, numInferenceSteps, seed): | |
global latents, lastSeed, generator, deviceStr, lastImage | |
if mask is None or pauseInference is True: | |
return lastImage | |
if staticLatents is False: | |
if deviceStr == "cuda": | |
latents = torch.randn((1, 4, 64, 64), device=device, dtype=torch.float16) | |
else: | |
latents = torch.randn((1, 4, 64, 64), device=device) | |
if lastSeed != seed: | |
generator = torch.Generator(device).manual_seed(seed) | |
lastSeed = seed | |
newImage = pipeline(prompt=prompt, | |
negative_prompt=negativePrompt, | |
image=inputImage, | |
mask_image=mask, | |
guidance_scale=guidanceScale, | |
num_inference_steps=numInferenceSteps, | |
latents=latents, | |
generator=generator).images[0] | |
lastImage = newImage | |
return newImage | |
# Default mask shown in the Mask input.
# NOTE(review): presumably 512x512 to match the webcam frame; no explicit
# width/height is passed to the pipeline — confirm the asset's size.
defaultMask = Image.open("assets/masks/sphere.png")

prompt = gradio.Textbox(label="Prompt", placeholder="A person in a room", lines=3)
negativePrompt = gradio.Textbox(label="Negative Prompt", placeholder="Text", lines=3)
inputImage = gradio.Image(label="Input Feed", source="webcam", shape=[512, 512], streaming=True)
mask = gradio.Image(label="Mask", type="pil", value=defaultMask)
outputImage = gradio.Image(label="Extrapolated Field of View")

# diffusers applies classifier-free guidance (and hence the negative prompt)
# only when guidance_scale > 1. The range therefore starts at 1.
# 7.5 is the diffusers default.
guidanceScale = gradio.Slider(label="Guidance Scale", minimum=1, maximum=20, value=7.5, step=0.5)
# At least one whole denoising step is required.
numInferenceSteps = gradio.Slider(label="Number of Inference Steps", minimum=1, maximum=100, value=25, step=1)
# step=1 makes every seed in the range selectable.
seed = gradio.Slider(label="Generator Seed", minimum=0, maximum=10000, value=4096, step=1)
staticLatents = gradio.Checkbox(label="Static Latents", value=True)
pauseInference = gradio.Checkbox(label="Pause Inference", value=False)

# Order must match the positional parameters of `diffuse`.
inputs = [staticLatents, inputImage, mask, pauseInference, prompt, negativePrompt, guidanceScale, numInferenceSteps, seed]
# live=True re-runs `diffuse` whenever an input changes, without a Submit button.
ux = gradio.Interface(fn=diffuse, title="View Diffusion", inputs=inputs, outputs=outputImage, live=True)
ux.launch()