ShreeKanade07 commited on
Commit
69205f8
·
1 Parent(s): 47aa85b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image, ImageDraw
3
+ import cv2
4
+ import numpy as np
5
+ from IPython.display import HTML
6
+ from base64 import b64encode
7
+
8
+ import torch
9
+ from torch import autocast
10
+ from torch.nn import functional as F
11
+ from diffusers import StableDiffusionPipeline, AutoencoderKL
12
+ from diffusers import UNet2DConditionModel, PNDMScheduler, LMSDiscreteScheduler
13
+ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
14
+ from transformers import CLIPTextModel, CLIPTokenizer
15
+ from tqdm.auto import tqdm
16
+ import gradio as gr
17
+
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
+ from diffusers import StableDiffusionInpaintPipeline
21
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
22
+ "ShreeKanade07/Real-Image-pipeline", torch_dtype=torch.float16
23
+ )
24
+ #pipe = pipe.to("cuda")
25
+
26
+
27
+
28
+
29
+ # Define the predict function
30
+ def predict(image,mask,prompt):
31
+
32
+ prompt = prompt
33
+ image = Image.fromarray(image)
34
+ image=image.convert("RGB").resize((512, 512))
35
+ mask_image = Image.fromarray(mask)
36
+ mask_image=mask_image.convert("RGB").resize((512, 512))
37
+ strength=0.9
38
+ generator = torch.manual_seed(32)
39
+ negative_prompt="zoomed in, blurry, oversaturated, warped,artifacts,flickers"
40
+ images = pipe(prompt=prompt, image=image, mask_image=mask_image, strength=strength, negative_prompt=negative_prompt, generator=generator).images
41
+ return images[0]
42
+
43
+
44
+
45
+ # Create the Gradio interface
46
+ gr.Interface(
47
+ predict,
48
+ title='Stable Diffusion In-Painting',
49
+ inputs=[
50
+ gr.Image(label='Image'),
51
+ gr.Image(label='Mask'),
52
+ gr.Textbox(label='Prompt')
53
+ ],
54
+ outputs=[
55
+ gr.Image(label='Output Image')
56
+ ]
57
+ ).launch(debug=True, share=True)