Buğrahan Dönmez commited on
Commit
6c29834
·
1 Parent(s): c62430a

Initialize the repo

Browse files
app.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from diffusers import DDIMScheduler,DiffusionPipeline
4
+ import torch.nn.functional as F
5
+ import cv2
6
+ from torchvision.utils import save_image
7
+ from diffusers.utils import load_image
8
+ from torchvision.transforms.functional import to_tensor, gaussian_blur
9
+ from matplotlib import pyplot as plt
10
+ import gradio as gr
11
+ import spaces
12
+ from gradio_imageslider import ImageSlider
13
+ from torchvision.transforms.functional import to_pil_image, to_tensor
14
+ from PIL import ImageFilter
15
+
16
+
17
+ def preprocess_image(input_image, device):
18
+ image = to_tensor(input_image)
19
+ image = image.unsqueeze_(0).float() * 2 - 1 # [0,1] --> [-1,1]
20
+ if image.shape[1] != 3:
21
+ image = image.expand(-1, 3, -1, -1)
22
+ image = F.interpolate(image, (1024, 1024))
23
+ image = image.to(dtype).to(device)
24
+
25
+ return image
26
+
27
+
28
+ def preprocess_mask(input_mask, device):
29
+ mask = to_tensor(input_mask.convert('L'))
30
+ mask = mask.unsqueeze_(0).float() # 0 or 1
31
+ mask = F.interpolate(mask, (1024, 1024))
32
+ mask = gaussian_blur(mask, kernel_size=(77, 77))
33
+ mask[mask < 0.1] = 0
34
+ mask[mask >= 0.1] = 1
35
+ mask = mask.to(dtype).to(device)
36
+
37
+ return mask
38
+
39
+
40
+ def make_redder(img, mask, increase_factor=0.4):
41
+ img_redder = img.clone()
42
+ mask_expanded = mask.expand_as(img)
43
+ img_redder[0][mask_expanded[0] == 1] = torch.clamp(img_redder[0][mask_expanded[0] == 1] + increase_factor, 0, 1)
44
+
45
+ return img_redder
46
+
47
+
48
+ # Model loading parameters
49
+ is_cpu_offload_enabled = False
50
+ is_attention_slicing_enabled = True
51
+
52
+ # Load model
53
+ dtype = torch.float16
54
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
55
+ scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
56
+
57
+ model_path = "stabilityai/stable-diffusion-xl-base-1.0"
58
+ pipeline = DiffusionPipeline.from_pretrained(
59
+ model_path,
60
+ custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser.py",
61
+ scheduler=scheduler,
62
+ variant="fp16",
63
+ use_safetensors=True,
64
+ torch_dtype=dtype,
65
+ ).to(device)
66
+
67
+ if is_attention_slicing_enabled:
68
+ pipeline.enable_attention_slicing()
69
+
70
+ if is_cpu_offload_enabled:
71
+ pipeline.enable_model_cpu_offload()
72
+
73
+
74
+ @spaces.GPU
75
+ def remove(gradio_image, rm_guidance_scale=9, num_inference_steps=50, seed=42, strength=0.8):
76
+ generator = torch.Generator('cuda').manual_seed(seed)
77
+ prompt = "" # Set prompt to null
78
+
79
+ source_image_pure = gradio_image["background"]
80
+ mask_image_pure = gradio_image["layers"][0]
81
+ source_image = preprocess_image(source_image_pure, device)
82
+ mask = preprocess_mask(mask_image_pure, device)
83
+
84
+ START_STEP = 0 # AAS start step
85
+ END_STEP = int(strength * num_inference_steps) # AAS end step
86
+ LAYER = 34 # 0~23down,24~33mid,34~69up /AAS start layer
87
+ END_LAYER = 70 # AAS end layer
88
+ ss_steps = 9 # similarity suppression steps
89
+ ss_scale = 0.3 # similarity suppression scale
90
+
91
+ image = pipeline(
92
+ prompt=prompt,
93
+ image=source_image,
94
+ mask_image=mask,
95
+ height=1024,
96
+ width=1024,
97
+ AAS=True, # enable AAS
98
+ strength=strength, # inpainting strength
99
+ rm_guidance_scale=rm_guidance_scale, # removal guidance scale
100
+ ss_steps = ss_steps, # similarity suppression steps
101
+ ss_scale = ss_scale, # similarity suppression scale
102
+ AAS_start_step=START_STEP, # AAS start step
103
+ AAS_start_layer=LAYER, # AAS start layer
104
+ AAS_end_layer=END_LAYER, # AAS end layer
105
+ num_inference_steps=num_inference_steps, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
106
+ generator=g,
107
+ guidance_scale=1,
108
+ output_type='pt'
109
+ ).images[0]
110
+
111
+ img = (source_image * 0.5 + 0.5).squeeze(0)
112
+ mask_red = mask.squeeze(0)
113
+ img_redder = make_redder(img, mask_red)
114
+
115
+ pil_mask = to_pil_image(mask.squeeze(0))
116
+ pil_mask_blurred = pil_mask.filter(ImageFilter.GaussianBlur(radius=15))
117
+ mask_blurred = to_tensor(pil_mask_blurred).unsqueeze_(0).to(mask.device)
118
+ mask_f = 1-(1 - mask) * (1 - mask_blurred)
119
+
120
+ image_1 = image.unsqueeze(0)
121
+
122
+ return source_image, image_1
123
+
124
+
125
+ title = """<h1 align="center">Object Remove</h1>"""
126
+ with gr.Blocks() as demo:
127
+ gr.HTML(title)
128
+ with gr.Row():
129
+ with gr.Column():
130
+ with gr.Accordion("Advanced Options", open=False):
131
+ guidance_scale = gr.Slider(
132
+ minimum=1,
133
+ maximum=20,
134
+ value=9,
135
+ step=0.1,
136
+ label="Guidance Scale"
137
+ )
138
+ num_steps = gr.Slider(
139
+ minimum=5,
140
+ maximum=100,
141
+ value=50,
142
+ step=1,
143
+ label="Steps"
144
+ )
145
+ seed = gr.Slider(
146
+ minimum=42,
147
+ maximum=100000000000,
148
+ value=42,
149
+ step=1,
150
+ label="Seed"
151
+ )
152
+ strength = gr.Slider(
153
+ minimum=0,
154
+ maximum=1,
155
+ value=0.8,
156
+ step=0.1,
157
+ label="Strength"
158
+ )
159
+
160
+ input_image = gr.ImageMask(
161
+ type="pil", label="Input Image",crop_size=(1200,1200), layers=False
162
+ )
163
+ with gr.Column():
164
+ with gr.Row():
165
+ with gr.Column():
166
+ run_button = gr.Button("Generate")
167
+
168
+ result = ImageSlider(
169
+ interactive=False,
170
+ label="Generated Image",
171
+ type="pil"
172
+ )
173
+
174
+ run_button.click(
175
+ fn=remove,
176
+ inputs=[input_image, guidance_scale, num_steps, seed, strength],
177
+ outputs=result,
178
+ )
pipeline_stable_diffusion_xl_attentive_eraser.py ADDED
The diff for this file is too large to render. See raw diff