Update app.py
app.py
CHANGED
@@ -4,10 +4,19 @@ import torch
 from PIL import Image
 import utils
 import streamlit as st
+import ptp_utils
+import seq_aligner
+import torch.nn.functional as nnf
+from typing import Optional, Union, Tuple, List, Callable, Dict
+import abc
+
+LOW_RESOURCE = False
+MAX_NUM_WORDS = 77
 
 is_colab = utils.is_google_colab()
 
-if True:
+
+if False:
     model_id_or_path = "CompVis/stable-diffusion-v1-4"
     scheduler = DDIMScheduler.from_config(model_id_or_path,
                                           use_auth_token=st.secrets["USER_TOKEN"],
@@ -15,21 +24,233 @@ if True:
     pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path,
                                                   use_auth_token=st.secrets["USER_TOKEN"],
                                                   scheduler=scheduler)
+    tokenizer = pipe.tokenizer
 
 if torch.cuda.is_available():
     pipe = pipe.to("cuda")
 
-
+device_print = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+class LocalBlend:
+
+    def __call__(self, x_t, attention_store):
+        k = 1
+        maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
+        maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, MAX_NUM_WORDS) for item in maps]
+        maps = torch.cat(maps, dim=1)
+        maps = (maps * self.alpha_layers).sum(-1).mean(1)
+        mask = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
+        mask = nnf.interpolate(mask, size=(x_t.shape[2:]))
+        mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
+        mask = mask.gt(self.threshold)
+        mask = (mask[:1] + mask[1:]).float()
+        x_t = x_t[:1] + mask * (x_t - x_t[:1])
+        return x_t
+
+    def __init__(self, prompts: List[str], words: [List[List[str]]], threshold=.3):
+        alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS)
+        for i, (prompt, words_) in enumerate(zip(prompts, words)):
+            if type(words_) is str:
+                words_ = [words_]
+            for word in words_:
+                ind = ptp_utils.get_word_inds(prompt, word, tokenizer)
+                alpha_layers[i, :, :, :, :, ind] = 1
+        self.alpha_layers = alpha_layers.to(device)
+        self.threshold = threshold
+
+
+class AttentionControl(abc.ABC):
+
+    def step_callback(self, x_t):
+        return x_t
+
+    def between_steps(self):
+        return
+
+    @property
+    def num_uncond_att_layers(self):
+        return self.num_att_layers if LOW_RESOURCE else 0
+
+    @abc.abstractmethod
+    def forward(self, attn, is_cross: bool, place_in_unet: str):
+        raise NotImplementedError
+
+    def __call__(self, attn, is_cross: bool, place_in_unet: str):
+        if self.cur_att_layer >= self.num_uncond_att_layers:
+            if LOW_RESOURCE:
+                attn = self.forward(attn, is_cross, place_in_unet)
+            else:
+                h = attn.shape[0]
+                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
+        self.cur_att_layer += 1
+        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
+            self.cur_att_layer = 0
+            self.cur_step += 1
+            self.between_steps()
+        return attn
+
+    def reset(self):
+        self.cur_step = 0
+        self.cur_att_layer = 0
+
+    def __init__(self):
+        self.cur_step = 0
+        self.num_att_layers = -1
+        self.cur_att_layer = 0
+
+
+class EmptyControl(AttentionControl):
+
+    def forward(self, attn, is_cross: bool, place_in_unet: str):
+        return attn
+
+
+class AttentionStore(AttentionControl):
+
+    @staticmethod
+    def get_empty_store():
+        return {"down_cross": [], "mid_cross": [], "up_cross": [],
+                "down_self": [], "mid_self": [], "up_self": []}
+
+    def forward(self, attn, is_cross: bool, place_in_unet: str):
+        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
+        if attn.shape[1] <= 32 ** 2:  # avoid memory overhead
+            self.step_store[key].append(attn)
+        return attn
+
+    def between_steps(self):
+        if len(self.attention_store) == 0:
+            self.attention_store = self.step_store
+        else:
+            for key in self.attention_store:
+                for i in range(len(self.attention_store[key])):
+                    self.attention_store[key][i] += self.step_store[key][i]
+        self.step_store = self.get_empty_store()
+
+    def get_average_attention(self):
+        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
+        return average_attention
+
+    def reset(self):
+        super(AttentionStore, self).reset()
+        self.step_store = self.get_empty_store()
+        self.attention_store = {}
+
+    def __init__(self):
+        super(AttentionStore, self).__init__()
+        self.step_store = self.get_empty_store()
+        self.attention_store = {}
+
+
+class AttentionControlEdit(AttentionStore, abc.ABC):
+
+    def step_callback(self, x_t):
+        if self.local_blend is not None:
+            x_t = self.local_blend(x_t, self.attention_store)
+        return x_t
+
+    def replace_self_attention(self, attn_base, att_replace):
+        if att_replace.shape[2] <= 16 ** 2:
+            return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
+        else:
+            return att_replace
+
+    @abc.abstractmethod
+    def replace_cross_attention(self, attn_base, att_replace):
+        raise NotImplementedError
+
+    def forward(self, attn, is_cross: bool, place_in_unet: str):
+        super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
+        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
+            h = attn.shape[0] // self.batch_size
+            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
+            attn_base, attn_repalce = attn[0], attn[1:]
+            if is_cross:
+                alpha_words = self.cross_replace_alpha[self.cur_step]
+                attn_replace_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce
+                attn[1:] = attn_replace_new
+            else:
+                attn[1:] = self.replace_self_attention(attn_base, attn_repalce)
+            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
+        return attn
+
+    def __init__(self, prompts, num_steps: int,
+                 cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
+                 self_replace_steps: Union[float, Tuple[float, float]],
+                 local_blend: Optional[LocalBlend]):
+        super(AttentionControlEdit, self).__init__()
+        self.batch_size = len(prompts)
+        self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, tokenizer).to(device)
+        if type(self_replace_steps) is float:
+            self_replace_steps = 0, self_replace_steps
+        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
+        self.local_blend = local_blend
+
+
+class AttentionReplace(AttentionControlEdit):
+
+    def replace_cross_attention(self, attn_base, att_replace):
+        return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
+
+    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
+                 local_blend: Optional[LocalBlend] = None):
+        super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
+        self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device)
+
+
+class AttentionRefine(AttentionControlEdit):
+
+    def replace_cross_attention(self, attn_base, att_replace):
+        attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
+        attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
+        return attn_replace
+
+    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
+                 local_blend: Optional[LocalBlend] = None):
+        super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
+        self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
+        self.mapper, alphas = self.mapper.to(device), alphas.to(device)
+        self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
+
+
+def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], Tuple[float, ...]]):
+    if type(word_select) is int or type(word_select) is str:
+        word_select = (word_select,)
+    equalizer = torch.ones(len(values), 77)
+    values = torch.tensor(values, dtype=torch.float32)
+    for word in word_select:
+        inds = ptp_utils.get_word_inds(text, word, tokenizer)
+        equalizer[:, inds] = values
+    return equalizer
 
 
 def inference(source_prompt, target_prompt, source_guidance_scale=1, guidance_scale=5, num_inference_steps=100,
-              width=512, height=512, seed=0, img=None, strength=0.7
+              width=512, height=512, seed=0, img=None, strength=0.7,
+              cross_attention_control=None, cross_replace_steps=0.8, self_replace_steps=0.4):
 
     torch.manual_seed(seed)
 
     ratio = min(height / img.height, width / img.width)
     img = img.resize((int(img.width * ratio), int(img.height * ratio)))
 
+    # create the CAC controller.
+    if cross_attention_control == "replace":
+        controller = AttentionReplace([source_prompt, target_prompt],
+                                      num_inference_steps,
+                                      cross_replace_steps=cross_replace_steps,
+                                      self_replace_steps=self_replace_steps,
+                                      )
+        ptp_utils.register_attention_control(pipe, controller)
+    elif cross_attention_control == "refine":
+        controller = AttentionRefine([source_prompt, target_prompt],
+                                     num_inference_steps,
+                                     cross_replace_steps=cross_replace_steps,
+                                     self_replace_steps=self_replace_steps,
+                                     )
+        ptp_utils.register_attention_control(pipe, controller)
+
     results = pipe(prompt=target_prompt,
                    source_prompt=source_prompt,
                    init_image=img,
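The heart of the new controller code is `AttentionReplace.replace_cross_attention`: the einsum `'hpw,bwn->bhpn'` pushes the source prompt's cross-attention maps (heads x pixels x tokens) through a token-to-token `mapper`, so each target-prompt word inherits the attention of the source word it is aligned with. A minimal sanity check of that contraction; the shapes below are illustrative toy values, not the ones used in the app:

import torch

heads, pixels, tokens = 8, 16 * 16, 77          # toy sizes; the app caps tokens at MAX_NUM_WORDS = 77
attn_base = torch.rand(heads, pixels, tokens)   # cross-attention maps of the source prompt
mapper = torch.eye(tokens).unsqueeze(0)         # batch of one identity mapper: no word is remapped

out = torch.einsum('hpw,bwn->bhpn', attn_base, mapper)
print(out.shape)                                # torch.Size([1, 8, 256, 77])
assert torch.allclose(out[0], attn_base)        # an identity mapper leaves the maps unchanged

In the app, the mapper instead comes from `seq_aligner.get_replacement_mapper(prompts, tokenizer)`, which, as used above, aligns the source and target prompts token by token so that only the swapped words are redirected.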
@@ -64,7 +285,7 @@ with gr.Blocks(css=css) as demo:
                <a href="https://huggingface.co/docs/diffusers/main/en/api/pipelines/cycle_diffusion">🧨 Pipeline doc</a> | <a href="https://arxiv.org/abs/2210.05559">📄 Paper link</a>
                </p>
                <p>You can skip the queue in the colab: <a href="https://colab.research.google.com/gist/ChenWu98/0aa4fe7be80f6b45d3d055df9f14353a/copy-of-fine-tuned-diffusion-gradio.ipynb"><img data-canonical-src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></p>
-               Running on <b>{
+               Running on <b>{device_print}</b>{(" in a <b>Google Colab</b>." if is_colab else "")}
                </p>
              </div>
        """
@@ -82,42 +303,58 @@ with gr.Blocks(css=css) as demo:
             # ).style(grid=[1], height="auto")
 
         with gr.Column(scale=45):
-            with gr.Tab("
+            with gr.Tab("Edit options"):
                 with gr.Group():
                     with gr.Row():
                         source_prompt = gr.Textbox(label="Source prompt", placeholder="Source prompt describes the input image")
+                        source_guidance_scale = gr.Slider(label="Source guidance scale", value=1, minimum=1, maximum=10)
                     with gr.Row():
                         target_prompt = gr.Textbox(label="Target prompt", placeholder="Target prompt describes the output image")
-
-                    with gr.Row():
-                        source_guidance_scale = gr.Slider(label="Source guidance scale", value=1, minimum=1, maximum=10)
                         guidance_scale = gr.Slider(label="Target guidance scale", value=5, minimum=1, maximum=10)
-
                     with gr.Row():
-                        num_inference_steps = gr.Slider(label="Number of inference steps", value=100, minimum=25, maximum=500, step=1)
                         strength = gr.Slider(label="Strength", value=0.7, minimum=0.5, maximum=1, step=0.01)
 
                     with gr.Row():
+                        generate = gr.Button(value="Edit")
+            with gr.Tab("Basic options"):
+                with gr.Group():
+                    with gr.Row():
+                        num_inference_steps = gr.Slider(label="Number of inference steps", value=100, minimum=25, maximum=500, step=1)
                         width = gr.Slider(label="Width", value=512, minimum=64, maximum=1024, step=8)
                         height = gr.Slider(label="Height", value=512, minimum=64, maximum=1024, step=8)
 
                    with gr.Row():
                        seed = gr.Slider(0, 2147483647, label='Seed', value=0, step=1)
+
+            with gr.Tab("CAC options"):
+                with gr.Group():
                     with gr.Row():
-
+                        cross_attention_control = gr.Radio(label="CAC type", choices=["None", "Replace", "Refine"], value="None")
+                    with gr.Row():
+                        # If not "None", the following two parameters will be used.
+                        cross_replace_steps = gr.Slider(label="Cross replace steps", value=0.8, minimum=0.0, maximum=1, step=0.01)
+                        self_replace_steps = gr.Slider(label="Self replace steps", value=0.4, minimum=0.0, maximum=1, step=0.01)
 
     inputs = [source_prompt, target_prompt, source_guidance_scale, guidance_scale, num_inference_steps,
-              width, height, seed, img, strength
+              width, height, seed, img, strength,
+              cross_attention_control, cross_replace_steps, self_replace_steps]
     generate.click(inference, inputs=inputs, outputs=image_out)
 
     ex = gr.Examples(
        [
-            ["An astronaut riding a horse", "An astronaut riding an elephant", 1, 2, 100,
-            ["
-            ["
-            ["A
+            ["An astronaut riding a horse", "An astronaut riding an elephant", 1, 2, 100, "images/astronaut_horse.png", 0.8, "None", 0, 0],
+            ["An astronaut riding a horse", "An astronaut riding a elephant", 1, 2, 100, "images/astronaut_horse.png", 0.9, "Replace", 0.15, 0.10],
+            ["A black colored car.", "A blue colored car.", 1, 2, 100, "images/black_car.png", 0.85, "None", 0, 0],
+            ["A black colored car.", "A blue colored car.", 1, 5, 100, "images/black_car.png", 0.95, "Replace", 0.8, 0.4],
+            ["A black colored car.", "A red colored car.", 1, 5, 100, "images/black_car.png", 1, "Replace", 0.8, 0.4],
+            ["An aerial view of autumn scene.", "An aerial view of winter scene.", 1, 5, 100, "images/mausoleum.png", 0.9, "None", 0.0, 0.0],
+            ["An aerial view of autumn scene.", "An aerial view of winter scene.", 1, 5, 100, "images/mausoleum.png", 1, "Replace", 0.8, 0.4],
+            ["A green apple and a black backpack on the floor.", "A red apple and a black backpack on the floor.", 1, 7, 100, "images/apple_bag.png", 0.9, "None", 0.0, 0.0],
+            ["A green apple and a black backpack on the floor.", "A red apple and a black backpack on the floor.", 1, 7, 100, "images/apple_bag.png", 0.9, "Replace", 0.8, 0.4],
        ],
-        [source_prompt, target_prompt, source_guidance_scale, guidance_scale, num_inference_steps,
+        [source_prompt, target_prompt, source_guidance_scale, guidance_scale, num_inference_steps,
+         img, strength,
+         cross_attention_control, cross_replace_steps, self_replace_steps],
        image_out, inference, cache_examples=False)
 
     gr.Markdown('''
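For reference, the cross-attention-controlled edit added in this commit can also be exercised outside the Gradio UI once the definitions above (`pipe`, `AttentionReplace`, `ptp_utils`) are in scope. The sketch below is illustrative, not the app's exact code path: the prompts and image path come from the examples list above, while the keyword arguments passed to `pipe(...)` beyond `prompt`, `source_prompt`, and `init_image` (strength, the two guidance scales, and `eta`) are assumptions based on the CycleDiffusionPipeline API of diffusers releases from that period, since the diff truncates that call.

from PIL import Image

source_prompt = "An astronaut riding a horse"
target_prompt = "An astronaut riding an elephant"
img = Image.open("images/astronaut_horse.png").convert("RGB").resize((512, 512))  # example image from the Space

# Inject the source prompt's cross-attention into the target branch for roughly
# the first 80% of the denoising steps (and self-attention for the first 40%).
controller = AttentionReplace([source_prompt, target_prompt],
                              num_steps=100,
                              cross_replace_steps=0.8,
                              self_replace_steps=0.4)
ptp_utils.register_attention_control(pipe, controller)

result = pipe(prompt=target_prompt,
              source_prompt=source_prompt,
              init_image=img,
              num_inference_steps=100,
              strength=0.8,               # the arguments from here on are assumed, not shown in the diff
              guidance_scale=2,
              source_guidance_scale=1,
              eta=0.1).images[0]
result.save("edited.png")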