Spaces:
Running
on
Zero
Running
on
Zero
realantonvoronov
committed on
Commit
·
1e17711
1
Parent(s):
e5b0112
update sampling parameters in pipeline and arguments in app
Browse files- app.py +22 -2
- models/pipeline.py +13 -3
app.py
CHANGED
@@ -27,11 +27,16 @@ def infer(
|
|
27 |
more_smooth=True,
|
28 |
smooth_start_si=2,
|
29 |
turn_off_cfg_start_si=10,
|
|
|
|
|
30 |
progress=gr.Progress(track_tqdm=True),
|
31 |
):
|
32 |
if randomize_seed:
|
33 |
seed = random.randint(0, MAX_SEED)
|
34 |
|
|
|
|
|
|
|
35 |
image = pipe(
|
36 |
prompt=prompt,
|
37 |
null_prompt=negative_prompt,
|
@@ -41,6 +46,7 @@ def infer(
|
|
41 |
more_smooth=more_smooth,
|
42 |
smooth_start_si=smooth_start_si,
|
43 |
turn_off_cfg_start_si=turn_off_cfg_start_si,
|
|
|
44 |
seed=seed,
|
45 |
)[0]
|
46 |
|
@@ -103,7 +109,7 @@ with gr.Blocks(css=css) as demo:
|
|
103 |
minimum=0.0,
|
104 |
maximum=10.,
|
105 |
step=0.5,
|
106 |
-
value=
|
107 |
)
|
108 |
|
109 |
with gr.Accordion("Advanced Settings", open=False):
|
@@ -140,12 +146,24 @@ with gr.Blocks(css=css) as demo:
|
|
140 |
value=2,
|
141 |
)
|
142 |
turn_off_cfg_start_si = gr.Slider(
|
143 |
-
label="Disable CFG
|
144 |
minimum=0,
|
145 |
maximum=10,
|
146 |
step=1,
|
147 |
value=8,
|
148 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
|
151 |
gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True)# cache_mode="lazy")
|
@@ -163,6 +181,8 @@ with gr.Blocks(css=css) as demo:
|
|
163 |
more_smooth,
|
164 |
smooth_start_si,
|
165 |
turn_off_cfg_start_si,
|
|
|
|
|
166 |
],
|
167 |
outputs=[result, seed],
|
168 |
)
|
|
|
27 |
more_smooth=True,
|
28 |
smooth_start_si=2,
|
29 |
turn_off_cfg_start_si=10,
|
30 |
+
more_diverse=True,
|
31 |
+
last_scale_temp=None,
|
32 |
progress=gr.Progress(track_tqdm=True),
|
33 |
):
|
34 |
if randomize_seed:
|
35 |
seed = random.randint(0, MAX_SEED)
|
36 |
|
37 |
+
|
38 |
+
turn_on_cfg_start_si = 2 if more_diverse else 0
|
39 |
+
|
40 |
image = pipe(
|
41 |
prompt=prompt,
|
42 |
null_prompt=negative_prompt,
|
|
|
46 |
more_smooth=more_smooth,
|
47 |
smooth_start_si=smooth_start_si,
|
48 |
turn_off_cfg_start_si=turn_off_cfg_start_si,
|
49 |
+
turn_on_cfg_start_si=turn_on_cfg_start_si,
|
50 |
seed=seed,
|
51 |
)[0]
|
52 |
|
|
|
109 |
minimum=0.0,
|
110 |
maximum=10.,
|
111 |
step=0.5,
|
112 |
+
value=6.,
|
113 |
)
|
114 |
|
115 |
with gr.Accordion("Advanced Settings", open=False):
|
|
|
146 |
value=2,
|
147 |
)
|
148 |
turn_off_cfg_start_si = gr.Slider(
|
149 |
+
label="Disable CFG starting scale",
|
150 |
minimum=0,
|
151 |
maximum=10,
|
152 |
step=1,
|
153 |
value=8,
|
154 |
)
|
155 |
+
with gr.Row():
|
156 |
+
more_diverse = gr.Checkbox(label="More diverse", value=True)
|
157 |
+
apply_late_temperature = gr.Checkbox(label="Temperature after disabling CFG", value=False)
|
158 |
+
last_scale_temp = gr.Slider(
|
159 |
+
label="Late temperature value",
|
160 |
+
minimum=0.1,
|
161 |
+
maximum=10,
|
162 |
+
step=0.1,
|
163 |
+
value=1,
|
164 |
+
)
|
165 |
+
if not apply_late_temperature:
|
166 |
+
last_scale_temp = None
|
167 |
|
168 |
|
169 |
gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True)# cache_mode="lazy")
|
|
|
181 |
more_smooth,
|
182 |
smooth_start_si,
|
183 |
turn_off_cfg_start_si,
|
184 |
+
more_diverse,
|
185 |
+
last_scale_temp,
|
186 |
],
|
187 |
outputs=[result, seed],
|
188 |
)
|
models/pipeline.py
CHANGED
@@ -91,7 +91,9 @@ class SwittiPipeline:
|
|
91 |
return_pil: bool = True,
|
92 |
smooth_start_si: int = 0,
|
93 |
turn_off_cfg_start_si: int = 10,
|
|
|
94 |
image_size: tuple[int, int] = (512, 512),
|
|
|
95 |
) -> torch.Tensor | list[PILImage]:
|
96 |
"""
|
97 |
only used for inference, on autoregressive mode
|
@@ -155,7 +157,8 @@ class SwittiPipeline:
|
|
155 |
else:
|
156 |
freqs_cis = switti.freqs_cis
|
157 |
|
158 |
-
if si >= turn_off_cfg_start_si:
|
|
|
159 |
x_BLC = x_BLC[:B]
|
160 |
context = context[:B]
|
161 |
context_attn_bias = context_attn_bias[:B]
|
@@ -170,6 +173,8 @@ class SwittiPipeline:
|
|
170 |
if b.cross_attn.caching and b.cross_attn.cached_k is not None:
|
171 |
b.cross_attn.cached_k = b.cross_attn.cached_k[:B]
|
172 |
b.cross_attn.cached_v = b.cross_attn.cached_v[:B]
|
|
|
|
|
173 |
|
174 |
for block in switti.blocks:
|
175 |
x_BLC = block(
|
@@ -186,11 +191,16 @@ class SwittiPipeline:
|
|
186 |
logits_BlV = switti.get_logits(x_BLC, cond_BD)
|
187 |
|
188 |
# Guidance
|
189 |
-
if si <
|
|
|
|
|
|
|
190 |
t = cfg
|
191 |
logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]
|
|
|
|
|
192 |
|
193 |
-
if
|
194 |
# not used when evaluating FID/IS/Precision/Recall
|
195 |
gum_t = max(0.27 * (1 - ratio * 0.95), 0.005) # refer to mask-git
|
196 |
idx_Bl = gumbel_softmax_with_rng(
|
|
|
91 |
return_pil: bool = True,
|
92 |
smooth_start_si: int = 0,
|
93 |
turn_off_cfg_start_si: int = 10,
|
94 |
+
turn_on_cfg_start_si: int = 0,
|
95 |
image_size: tuple[int, int] = (512, 512),
|
96 |
+
last_scale_temp: None | float = None,
|
97 |
) -> torch.Tensor | list[PILImage]:
|
98 |
"""
|
99 |
only used for inference, on autoregressive mode
|
|
|
157 |
else:
|
158 |
freqs_cis = switti.freqs_cis
|
159 |
|
160 |
+
if si < turn_on_cfg_start_si or si >= turn_off_cfg_start_si:
|
161 |
+
apply_smooth = False
|
162 |
x_BLC = x_BLC[:B]
|
163 |
context = context[:B]
|
164 |
context_attn_bias = context_attn_bias[:B]
|
|
|
173 |
if b.cross_attn.caching and b.cross_attn.cached_k is not None:
|
174 |
b.cross_attn.cached_k = b.cross_attn.cached_k[:B]
|
175 |
b.cross_attn.cached_v = b.cross_attn.cached_v[:B]
|
176 |
+
else:
|
177 |
+
apply_smooth = more_smooth
|
178 |
|
179 |
for block in switti.blocks:
|
180 |
x_BLC = block(
|
|
|
191 |
logits_BlV = switti.get_logits(x_BLC, cond_BD)
|
192 |
|
193 |
# Guidance
|
194 |
+
if si < turn_on_cfg_start_si:
|
195 |
+
t = 0 # no guidance
|
196 |
+
elif si >= turn_on_cfg_start_si and si < turn_off_cfg_start_si:
|
197 |
+
# default const cfg
|
198 |
t = cfg
|
199 |
logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]
|
200 |
+
elif last_scale_temp is not None:
|
201 |
+
logits_BlV = logits_BlV / last_scale_temp
|
202 |
|
203 |
+
if apply_smooth and si >= smooth_start_si:
|
204 |
# not used when evaluating FID/IS/Precision/Recall
|
205 |
gum_t = max(0.27 * (1 - ratio * 0.95), 0.005) # refer to mask-git
|
206 |
idx_Bl = gumbel_softmax_with_rng(
|