realantonvoronov commited on
Commit
b5f551b
·
1 Parent(s): be0617b

remove apply_late_temperature checkbox

Browse files
Files changed (2) hide show
  1. app.py +3 -8
  2. models/pipeline.py +2 -5
app.py CHANGED
@@ -28,8 +28,7 @@ def infer(
28
  smooth_start_si=2,
29
  turn_off_cfg_start_si=10,
30
  more_diverse=True,
31
- apply_late_temperature=False,
32
- last_scale_temp=None,
33
  progress=gr.Progress(track_tqdm=True),
34
  ):
35
  if randomize_seed:
@@ -49,7 +48,6 @@ def infer(
49
  turn_off_cfg_start_si=turn_off_cfg_start_si,
50
  turn_on_cfg_start_si=turn_on_cfg_start_si,
51
  seed=seed,
52
- apply_late_temperature=apply_late_temperature,
53
  last_scale_temp=last_scale_temp,
54
  )[0]
55
 
@@ -141,17 +139,15 @@ with gr.Blocks(css=css) as demo:
141
  )
142
  with gr.Row():
143
  more_diverse = gr.Checkbox(label="More diverse", value=True)
144
- apply_late_temperature = gr.Checkbox(label="Temperature after disabling CFG", value=False)
145
  last_scale_temp = gr.Slider(
146
- label="Late temperature value",
147
  minimum=0.1,
148
  maximum=10,
149
  step=0.1,
150
  value=0.1,
151
  )
152
 
153
-
154
- gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True)# cache_mode="lazy")
155
  gr.on(
156
  triggers=[run_button.click, prompt.submit],
157
  fn=infer,
@@ -167,7 +163,6 @@ with gr.Blocks(css=css) as demo:
167
  smooth_start_si,
168
  turn_off_cfg_start_si,
169
  more_diverse,
170
- apply_late_temperature,
171
  last_scale_temp,
172
  ],
173
  outputs=[result, seed],
 
28
  smooth_start_si=2,
29
  turn_off_cfg_start_si=10,
30
  more_diverse=True,
31
+ last_scale_temp=1,
 
32
  progress=gr.Progress(track_tqdm=True),
33
  ):
34
  if randomize_seed:
 
48
  turn_off_cfg_start_si=turn_off_cfg_start_si,
49
  turn_on_cfg_start_si=turn_on_cfg_start_si,
50
  seed=seed,
 
51
  last_scale_temp=last_scale_temp,
52
  )[0]
53
 
 
139
  )
140
  with gr.Row():
141
  more_diverse = gr.Checkbox(label="More diverse", value=True)
 
142
  last_scale_temp = gr.Slider(
143
+ label="Temperature after disabling CFG",
144
  minimum=0.1,
145
  maximum=10,
146
  step=0.1,
147
  value=0.1,
148
  )
149
 
150
+ gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=False)# cache_mode="lazy")
 
151
  gr.on(
152
  triggers=[run_button.click, prompt.submit],
153
  fn=infer,
 
163
  smooth_start_si,
164
  turn_off_cfg_start_si,
165
  more_diverse,
 
166
  last_scale_temp,
167
  ],
168
  outputs=[result, seed],
models/pipeline.py CHANGED
@@ -93,8 +93,7 @@ class SwittiPipeline:
93
  turn_off_cfg_start_si: int = 10,
94
  turn_on_cfg_start_si: int = 0,
95
  image_size: tuple[int, int] = (512, 512),
96
- apply_late_temperature: bool = False,
97
- last_scale_temp: None | float = None,
98
  ) -> torch.Tensor | list[PILImage]:
99
  """
100
  only used for inference, on autoregressive mode
@@ -107,8 +106,6 @@ class SwittiPipeline:
107
  :param more_smooth: sampling using gumbel softmax; only used in visualization, not used in FID/IS benchmarking
108
  :return: if return_pil: list of PIL Images, else: torch.tensor (B, 3, H, W) in [0, 1]
109
  """
110
- if not apply_late_temperature:
111
- last_scale_temp = None
112
  assert not self.switti.training
113
  switti = self.switti
114
  vae = self.vae
@@ -200,7 +197,7 @@ class SwittiPipeline:
200
  # default const cfg
201
  t = cfg
202
  logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]
203
- elif last_scale_temp is not None:
204
  logits_BlV = logits_BlV / last_scale_temp
205
 
206
  if apply_smooth and si >= smooth_start_si:
 
93
  turn_off_cfg_start_si: int = 10,
94
  turn_on_cfg_start_si: int = 0,
95
  image_size: tuple[int, int] = (512, 512),
96
+ last_scale_temp: float = 1.,
 
97
  ) -> torch.Tensor | list[PILImage]:
98
  """
99
  only used for inference, on autoregressive mode
 
106
  :param more_smooth: sampling using gumbel softmax; only used in visualization, not used in FID/IS benchmarking
107
  :return: if return_pil: list of PIL Images, else: torch.tensor (B, 3, H, W) in [0, 1]
108
  """
 
 
109
  assert not self.switti.training
110
  switti = self.switti
111
  vae = self.vae
 
197
  # default const cfg
198
  t = cfg
199
  logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]
200
+ else:
201
  logits_BlV = logits_BlV / last_scale_temp
202
 
203
  if apply_smooth and si >= smooth_start_si: