Linoy Tsaban
commited on
Commit
•
c633c03
1
Parent(s):
9cd1904
Update app.py
Browse filescaption image with BLIP
app.py
CHANGED
@@ -4,16 +4,37 @@ import numpy as np
|
|
4 |
import requests
|
5 |
import random
|
6 |
from io import BytesIO
|
7 |
-
from diffusers import StableDiffusionPipeline
|
8 |
-
from diffusers import DDIMScheduler
|
9 |
from utils import *
|
10 |
from inversion_utils import *
|
11 |
from modified_pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline
|
12 |
from torch import autocast, inference_mode
|
13 |
-
import
|
|
|
|
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
|
18 |
|
19 |
# inverts a real image according to Algorihm 1 in https://arxiv.org/pdf/2304.06140.pdf,
|
@@ -35,7 +56,6 @@ def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta
|
|
35 |
return zs, wts
|
36 |
|
37 |
|
38 |
-
|
39 |
def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta = 1):
|
40 |
|
41 |
# reverse process (via Zs and wT)
|
@@ -49,85 +69,13 @@ def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta = 1):
|
|
49 |
img = image_grid(x0_dec)
|
50 |
return img
|
51 |
|
52 |
-
# load pipelines
|
53 |
-
sd_model_id = "stabilityai/stable-diffusion-2-base"
|
54 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
55 |
-
sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
|
56 |
-
sd_pipe.scheduler = DDIMScheduler.from_config(sd_model_id, subfolder = "scheduler")
|
57 |
-
sem_pipe = SemanticStableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
|
58 |
-
|
59 |
-
|
60 |
-
def get_example():
|
61 |
-
case = [
|
62 |
-
[
|
63 |
-
'examples/source_a_cat_sitting_next_to_a_mirror.jpeg',
|
64 |
-
'a cat sitting next to a mirror',
|
65 |
-
'watercolor painting of a cat sitting next to a mirror',
|
66 |
-
100,
|
67 |
-
36,
|
68 |
-
15,
|
69 |
-
'Schnauzer dog', 'cat',
|
70 |
-
5.5,
|
71 |
-
1,
|
72 |
-
'examples/ddpm_sega_watercolor_painting_a_cat_sitting_next_to_a_mirror_plus_dog_minus_cat.png'
|
73 |
-
],
|
74 |
-
[
|
75 |
-
'examples/source_a_man_wearing_a_brown_hoodie_in_a_crowded_street.jpeg',
|
76 |
-
'a man wearing a brown hoodie in a crowded street',
|
77 |
-
'a robot wearing a brown hoodie in a crowded street',
|
78 |
-
100,
|
79 |
-
36,
|
80 |
-
15,
|
81 |
-
'painting','',
|
82 |
-
10,
|
83 |
-
1,
|
84 |
-
'examples/ddpm_sega_painting_of_a_robot_wearing_a_brown_hoodie_in_a_crowded_street.png'
|
85 |
-
],
|
86 |
-
[
|
87 |
-
'examples/source_wall_with_framed_photos.jpeg',
|
88 |
-
'',
|
89 |
-
'',
|
90 |
-
100,
|
91 |
-
36,
|
92 |
-
15,
|
93 |
-
'pink drawings of muffins','',
|
94 |
-
10,
|
95 |
-
1,
|
96 |
-
'examples/ddpm_sega_plus_pink_drawings_of_muffins.png'
|
97 |
-
],
|
98 |
-
[
|
99 |
-
'examples/source_an_empty_room_with_concrete_walls.jpg',
|
100 |
-
'an empty room with concrete walls',
|
101 |
-
'glass walls',
|
102 |
-
100,
|
103 |
-
36,
|
104 |
-
17,
|
105 |
-
'giant elephant','',
|
106 |
-
10,
|
107 |
-
1,
|
108 |
-
'examples/ddpm_sega_glass_walls_gian_elephant.png'
|
109 |
-
]]
|
110 |
-
return case
|
111 |
-
|
112 |
-
def randomize_seed_fn(seed, randomize_seed):
|
113 |
-
if randomize_seed:
|
114 |
-
seed = random.randint(0, np.iinfo(np.int32).max)
|
115 |
-
torch.manual_seed(seed)
|
116 |
-
return seed
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
|
121 |
def reconstruct(tar_prompt,
|
122 |
tar_cfg_scale,
|
123 |
skip,
|
124 |
wts, zs,
|
125 |
-
# do_reconstruction,
|
126 |
-
# reconstruction
|
127 |
):
|
128 |
|
129 |
-
|
130 |
-
# if do_reconstruction:
|
131 |
reconstruction = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
|
132 |
return reconstruction
|
133 |
|
@@ -158,6 +106,7 @@ def load_and_invert(
|
|
158 |
|
159 |
return wts, zs, do_inversion
|
160 |
|
|
|
161 |
|
162 |
def edit(input_image,
|
163 |
wts, zs,
|
@@ -197,6 +146,66 @@ def edit(input_image,
|
|
197 |
|
198 |
|
199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
|
201 |
########
|
202 |
# demo #
|
@@ -346,6 +355,7 @@ with gr.Blocks(css='style.css') as demo:
|
|
346 |
|
347 |
|
348 |
with gr.Row():
|
|
|
349 |
run_button = gr.Button("Run")
|
350 |
reconstruct_button = gr.Button("Show Reconstruction", visible=False)
|
351 |
|
@@ -366,11 +376,14 @@ with gr.Blocks(css='style.css') as demo:
|
|
366 |
with gr.Accordion("Help", open=False):
|
367 |
gr.Markdown(help_text)
|
368 |
|
369 |
-
|
|
|
|
|
|
|
|
|
370 |
|
371 |
add_concept_button.click(fn = add_concept, inputs=sega_concepts_counter,
|
372 |
outputs= [row2, row3, add_concept_button, sega_concepts_counter], queue = False)
|
373 |
-
|
374 |
|
375 |
run_button.click(
|
376 |
fn = randomize_seed_fn,
|
|
|
4 |
import requests
|
5 |
import random
|
6 |
from io import BytesIO
|
|
|
|
|
7 |
from utils import *
|
8 |
from inversion_utils import *
|
9 |
from modified_pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline
|
10 |
from torch import autocast, inference_mode
|
11 |
+
from diffusers import StableDiffusionPipeline
|
12 |
+
from diffusers import DDIMScheduler
|
13 |
+
from transformers import AutoProcessor, BlipForConditionalGeneration
|
14 |
|
15 |
+
# load pipelines
|
16 |
+
sd_model_id = "stabilityai/stable-diffusion-2-base"
|
17 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
18 |
+
sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
|
19 |
+
sd_pipe.scheduler = DDIMScheduler.from_config(sd_model_id, subfolder = "scheduler")
|
20 |
+
sem_pipe = SemanticStableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
|
21 |
+
blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
|
22 |
+
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
|
23 |
|
24 |
|
25 |
+
|
26 |
+
## IMAGE CPATIONING ##
|
27 |
+
def caption_image(input_image):
|
28 |
+
|
29 |
+
inputs = blip_processor(images=image, return_tensors="pt")
|
30 |
+
pixel_values = inputs.pixel_values
|
31 |
+
|
32 |
+
generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
|
33 |
+
generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
34 |
+
return generated_caption
|
35 |
+
|
36 |
+
|
37 |
+
## DDPM INVERSION AND SAMPLING ##
|
38 |
def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
|
39 |
|
40 |
# inverts a real image according to Algorihm 1 in https://arxiv.org/pdf/2304.06140.pdf,
|
|
|
56 |
return zs, wts
|
57 |
|
58 |
|
|
|
59 |
def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta = 1):
|
60 |
|
61 |
# reverse process (via Zs and wT)
|
|
|
69 |
img = image_grid(x0_dec)
|
70 |
return img
|
71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
def reconstruct(tar_prompt,
|
74 |
tar_cfg_scale,
|
75 |
skip,
|
76 |
wts, zs,
|
|
|
|
|
77 |
):
|
78 |
|
|
|
|
|
79 |
reconstruction = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
|
80 |
return reconstruction
|
81 |
|
|
|
106 |
|
107 |
return wts, zs, do_inversion
|
108 |
|
109 |
+
## SEGA ##
|
110 |
|
111 |
def edit(input_image,
|
112 |
wts, zs,
|
|
|
146 |
|
147 |
|
148 |
|
149 |
+
def randomize_seed_fn(seed, randomize_seed):
|
150 |
+
if randomize_seed:
|
151 |
+
seed = random.randint(0, np.iinfo(np.int32).max)
|
152 |
+
torch.manual_seed(seed)
|
153 |
+
return seed
|
154 |
+
|
155 |
+
|
156 |
+
def get_example():
|
157 |
+
case = [
|
158 |
+
[
|
159 |
+
'examples/source_a_cat_sitting_next_to_a_mirror.jpeg',
|
160 |
+
'a cat sitting next to a mirror',
|
161 |
+
'watercolor painting of a cat sitting next to a mirror',
|
162 |
+
100,
|
163 |
+
36,
|
164 |
+
15,
|
165 |
+
'Schnauzer dog', 'cat',
|
166 |
+
5.5,
|
167 |
+
1,
|
168 |
+
'examples/ddpm_sega_watercolor_painting_a_cat_sitting_next_to_a_mirror_plus_dog_minus_cat.png'
|
169 |
+
],
|
170 |
+
[
|
171 |
+
'examples/source_a_man_wearing_a_brown_hoodie_in_a_crowded_street.jpeg',
|
172 |
+
'a man wearing a brown hoodie in a crowded street',
|
173 |
+
'a robot wearing a brown hoodie in a crowded street',
|
174 |
+
100,
|
175 |
+
36,
|
176 |
+
15,
|
177 |
+
'painting','',
|
178 |
+
10,
|
179 |
+
1,
|
180 |
+
'examples/ddpm_sega_painting_of_a_robot_wearing_a_brown_hoodie_in_a_crowded_street.png'
|
181 |
+
],
|
182 |
+
[
|
183 |
+
'examples/source_wall_with_framed_photos.jpeg',
|
184 |
+
'',
|
185 |
+
'',
|
186 |
+
100,
|
187 |
+
36,
|
188 |
+
15,
|
189 |
+
'pink drawings of muffins','',
|
190 |
+
10,
|
191 |
+
1,
|
192 |
+
'examples/ddpm_sega_plus_pink_drawings_of_muffins.png'
|
193 |
+
],
|
194 |
+
[
|
195 |
+
'examples/source_an_empty_room_with_concrete_walls.jpg',
|
196 |
+
'an empty room with concrete walls',
|
197 |
+
'glass walls',
|
198 |
+
100,
|
199 |
+
36,
|
200 |
+
17,
|
201 |
+
'giant elephant','',
|
202 |
+
10,
|
203 |
+
1,
|
204 |
+
'examples/ddpm_sega_glass_walls_gian_elephant.png'
|
205 |
+
]]
|
206 |
+
return case
|
207 |
+
|
208 |
+
|
209 |
|
210 |
########
|
211 |
# demo #
|
|
|
355 |
|
356 |
|
357 |
with gr.Row():
|
358 |
+
caption_button = gr.Button("Caption Image")
|
359 |
run_button = gr.Button("Run")
|
360 |
reconstruct_button = gr.Button("Show Reconstruction", visible=False)
|
361 |
|
|
|
376 |
with gr.Accordion("Help", open=False):
|
377 |
gr.Markdown(help_text)
|
378 |
|
379 |
+
caption_button.click(
|
380 |
+
fn = caption_image,
|
381 |
+
inputs = [input_image],
|
382 |
+
outputs = [tar_prompt]
|
383 |
+
)
|
384 |
|
385 |
add_concept_button.click(fn = add_concept, inputs=sega_concepts_counter,
|
386 |
outputs= [row2, row3, add_concept_button, sega_concepts_counter], queue = False)
|
|
|
387 |
|
388 |
run_button.click(
|
389 |
fn = randomize_seed_fn,
|