ShreeKanade07's picture
Update app.py
8a7f5de verified
raw
history blame
1.73 kB
import os
from PIL import Image, ImageDraw
import numpy as np
import torch
from torch import autocast
from torch.nn import functional as F
from diffusers import StableDiffusionPipeline, AutoencoderKL
from diffusers import UNet2DConditionModel, PNDMScheduler, LMSDiscreteScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from tqdm.auto import tqdm
import gradio as gr
# Use the GPU when one is available; the pipeline is moved there below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision only pays off on CUDA. On CPU it is slow and some PyTorch
# CPU kernels may not support float16, so fall back to float32 there.
dtype = torch.float16 if device.type == "cuda" else torch.float32
from diffusers import StableDiffusionInpaintPipeline
# Inpainting pipeline, loaded once at import time and reused by predict().
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "ShreeKanade07/Real-Image-pipeline", torch_dtype=dtype
)
pipe = pipe.to(device)
def predict(image, mask, prompt, *, strength=0.9, seed=32, num_inference_steps=20):
    """Inpaint the masked region of ``image`` according to ``prompt``.

    Args:
        image: Source picture, passed to ``Image.fromarray`` (presumably the
            NumPy array produced by the ``gr.Image`` input); converted to RGB
            and resized to 512x512.
        mask: Mask picture, handled the same way as ``image`` so both line up.
        prompt: Text describing what to paint into the masked region.
        strength: Forwarded to the pipeline's ``strength`` argument.
        seed: Seed for a private ``torch.Generator`` so results are repeatable.
        num_inference_steps: Number of denoising steps run by the pipeline.

    Returns:
        The first PIL image returned by the inpainting pipeline.

    Raises:
        gr.Error: If the image or the mask was not provided.
    """
    if image is None or mask is None:
        # An empty image component reaches us as None; report it in the UI
        # instead of failing inside Image.fromarray.
        raise gr.Error("Please provide both an image and a mask.")

    init_image = Image.fromarray(image).convert("RGB").resize((512, 512))
    mask_image = Image.fromarray(mask).convert("RGB").resize((512, 512))

    # A dedicated CPU generator yields the same random stream as
    # torch.manual_seed(seed) without reseeding torch's global RNG.
    generator = torch.Generator(device="cpu").manual_seed(seed)

    negative_prompt = "zoomed in, blurry, oversaturated, warped,artifacts,flickers"
    result = pipe(
        prompt=prompt,
        image=init_image,
        mask_image=mask_image,
        strength=strength,
        negative_prompt=negative_prompt,
        generator=generator,
        num_inference_steps=num_inference_steps,
    )
    return result.images[0]
# Gradio UI: an image, a mask and a prompt go in; one inpainted image comes out.
# The bundled example is run once up front because cache_examples=True.
demo = gr.Interface(
    fn=predict,
    title='Stable Diffusion Sketch In-Painting',
    inputs=[
        gr.Image(label='Image'),
        gr.Image(label='Mask'),
        gr.Textbox(label='Prompt'),
    ],
    outputs=[gr.Image(label='Output Image')],
    examples=[["IMG1.png", "IMG1_Mask.png", 'Make it real one']],
    cache_examples=True,
)
demo.launch(debug=True, share=True)