realantonvoronov committed
Commit 1dc27f0 · 1 Parent(s): 484ca0e

update for 1024

Files changed (5):
  1. app.py +8 -8
  2. models/helpers.py +7 -0
  3. models/pipeline.py +21 -15
  4. models/switti.py +4 -5
  5. models/vqvae.py +3 -2
app.py CHANGED
@@ -1,16 +1,16 @@
-import gradio as gr
-import numpy as np
 import random
 
+import gradio as gr
+import numpy as np
 import spaces
-from models import SwittiPipeline
 import torch
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "yresearch/Switti"
+from models import SwittiPipeline
 
 
-pipe = SwittiPipeline.from_pretrained(model_repo_id, device=device)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model_repo_id = "yresearch/Switti-1024"
+pipe = SwittiPipeline.from_pretrained(model_repo_id, device=device, torch_dtype=torch.bfloat16)
 
 MAX_SEED = np.iinfo(np.int32).max
 
@@ -140,9 +140,9 @@ with gr.Blocks(css=css) as demo:
            turn_off_cfg_start_si = gr.Slider(
                label="Disable CFG starting scale",
                minimum=0,
-               maximum=10,
+               maximum=14,
                step=1,
-               value=8,
+               value=11,
            )
            with gr.Row():
                more_diverse = gr.Checkbox(label="More diverse", value=False)
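The demo now loads the 1024 checkpoint in bfloat16, and since the 1024 schedule in models/helpers.py has 14 scales, the "Disable CFG starting scale" slider goes up to 14 with a default of 11. A minimal sketch of the same setup outside Gradio; the prompt is illustrative and it assumes generation goes through the pipeline's `__call__`:

```python
# Sketch only: mirrors the updated app.py setup (prompt text and __call__ usage are assumptions).
import torch

from models import SwittiPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = SwittiPipeline.from_pretrained(
    "yresearch/Switti-1024", device=device, torch_dtype=torch.bfloat16
)

# The 1024 model uses a 14-scale schedule, so CFG can stay on until scale 11 (the new default).
images = pipe(
    prompt="a red sports car on a mountain road",  # illustrative prompt
    cfg=6.0,
    turn_off_cfg_start_si=11,
)
```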
models/helpers.py CHANGED
@@ -3,6 +3,13 @@ from torch import nn as nn
 from torch.nn import functional as F
 
 
+RESOLUTION_PATCH_NUMS_MAPPING = {
+    256: "1_2_3_4_5_6_8_10_13_16",
+    512: "1_2_3_4_6_9_13_18_24_32",
+    1024: "1_2_3_4_5_7_9_12_16_21_27_36_48_64",
+}
+
+
 def sample_with_top_k_top_p_(
     logits_BlV: torch.Tensor,
     top_k: int = 0,
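RESOLUTION_PATCH_NUMS_MAPPING stores each resolution's scale schedule as an underscore-joined string; switti.py and vqvae.py in this commit both parse it back into a tuple of ints. A small sketch of that parsing (the helper name `patch_nums_for` is introduced here only for illustration):

```python
# Sketch of how the mapping is consumed elsewhere in this commit.
from models.helpers import RESOLUTION_PATCH_NUMS_MAPPING

def patch_nums_for(reso: int) -> tuple[int, ...]:
    # Same parsing as in SwittiHF.__init__ and VQVAEHF.__init__
    return tuple(int(x) for x in RESOLUTION_PATCH_NUMS_MAPPING[reso].split("_"))

print(patch_nums_for(512))   # (1, 2, 3, 4, 6, 9, 13, 18, 24, 32)           -> 10 scales
print(patch_nums_for(1024))  # (1, 2, 3, 4, 5, 7, 9, 12, 16, ..., 48, 64)   -> 14 scales
```

The 14-entry schedule for 1024 px is what motivates raising the CFG cut-off slider's maximum to 14 in app.py; assuming the VQVAE's usual 16x spatial downsampling, the last entry times 16 recovers the pixel resolution (16 -> 256, 32 -> 512, 64 -> 1024).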
models/pipeline.py CHANGED
@@ -8,14 +8,16 @@ from models.switti import SwittiHF, get_crop_condition
 from models.helpers import sample_with_top_k_top_p_, gumbel_softmax_with_rng
 
 
+TRAIN_IMAGE_SIZE = (512, 512)
+
 class SwittiPipeline:
     vae_path = "yresearch/VQVAE-Switti"
     text_encoder_path = "openai/clip-vit-large-patch14"
     text_encoder_2_path = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
 
-    def __init__(self, switti, vae, text_encoder, text_encoder_2, device,
-                 dtype=torch.bfloat16,
-                 ):
+    def __init__(self, switti, vae, text_encoder, text_encoder_2,
+                 device, dtype=torch.float32,
+                 ):
         self.switti = switti.to(dtype)
         self.vae = vae.to(dtype)
         self.text_encoder = text_encoder.to(dtype)
@@ -27,13 +29,18 @@ class SwittiPipeline:
         self.device = device
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, device="cuda"):
-        switti = SwittiHF.from_pretrained(pretrained_model_name_or_path, device=device).to(device)
-        vae = VQVAEHF.from_pretrained(cls.vae_path).to(device)
+    def from_pretrained(cls,
+                        pretrained_model_name_or_path,
+                        torch_dtype=torch.bfloat16,
+                        device="cuda",
+                        reso=1024,
+                        ):
+        switti = SwittiHF.from_pretrained(pretrained_model_name_or_path).to(device)
+        vae = VQVAEHF.from_pretrained(cls.vae_path, reso=reso).to(device)
         text_encoder = FrozenCLIPEmbedder(cls.text_encoder_path, device=device)
         text_encoder_2 = FrozenCLIPEmbedder(cls.text_encoder_2_path, device=device)
 
-        return cls(switti, vae, text_encoder, text_encoder_2, device)
+        return cls(switti, vae, text_encoder, text_encoder_2, device, torch_dtype)
 
     @staticmethod
     def to_image(tensor):
@@ -84,7 +91,7 @@ class SwittiPipeline:
         prompt: str | list[str],
         null_prompt: str = "",
         seed: int | None = None,
-        cfg: float = 4.0,
+        cfg: float = 6.,
         top_k: int = 400,
         top_p: float = 0.95,
         more_smooth: bool = False,
@@ -92,8 +99,7 @@ class SwittiPipeline:
         smooth_start_si: int = 0,
         turn_off_cfg_start_si: int = 10,
         turn_on_cfg_start_si: int = 0,
-        image_size: tuple[int, int] = (512, 512),
-        last_scale_temp: float = 1.,
+        last_scale_temp: None | float = None,
     ) -> torch.Tensor | list[PILImage]:
         """
         only used for inference, on autoregressive mode
@@ -122,8 +128,8 @@ class SwittiPipeline:
         cond_vector = switti.text_pooler(cond_vector)
 
         if switti.use_crop_cond:
-            crop_coords = get_crop_condition(2 * B * [image_size[0]],
-                                             2 * B * [image_size[1]],
+            crop_coords = get_crop_condition(2 * B * [TRAIN_IMAGE_SIZE[0]],
+                                             2 * B * [TRAIN_IMAGE_SIZE[1]],
                                              ).to(cond_vector.device)
             crop_embed = switti.crop_embed(crop_coords.view(-1)).reshape(2 * B, switti.D)
             crop_cond = switti.crop_proj(crop_embed)
@@ -169,7 +175,7 @@ class SwittiPipeline:
                 if b.attn.caching and b.attn.cached_k is not None:
                     b.attn.cached_k = b.attn.cached_k[:B]
                     b.attn.cached_v = b.attn.cached_v[:B]
-                if b.cross_attn.caching and b.cross_attn.cached_k is not None:
+                if b.cross_attn.caching and b.cross_attn.cached_k is not None:
                     b.cross_attn.cached_k = b.cross_attn.cached_k[:B]
                     b.cross_attn.cached_v = b.cross_attn.cached_v[:B]
             else:
@@ -197,7 +203,7 @@ class SwittiPipeline:
                 # default const cfg
                 t = cfg
                 logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]
-            else:
+            elif last_scale_temp is not None:
                 logits_BlV = logits_BlV / last_scale_temp
 
             if apply_smooth and si >= smooth_start_si:
@@ -208,7 +214,7 @@ class SwittiPipeline:
                 )
                 h_BChw = idx_Bl @ vae_quant.embedding.weight.unsqueeze(0)
             else:
-                # defaul nucleus sampling
+                # default nucleus sampling
                 idx_Bl = sample_with_top_k_top_p_(
                     logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1,
                 )[:, :, 0]
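The loader now takes the dtype and resolution explicitly, and `last_scale_temp` is optional: the last-scale logits are only rescaled when a temperature is actually passed. A hedged sketch of the new interface, assuming the sampling arguments shown above belong to the pipeline's `__call__` (the prompt text is illustrative):

```python
# Sketch of the updated loading/sampling interface from this commit.
import torch

from models import SwittiPipeline

pipe = SwittiPipeline.from_pretrained(
    "yresearch/Switti-1024",
    torch_dtype=torch.bfloat16,  # forwarded to __init__, which casts all modules
    device="cuda",
    reso=1024,                   # picks the matching patch-num schedule for the VQVAE
)

# last_scale_temp=None (the new default) leaves the last-scale logits untouched;
# pass a float to divide them by that temperature instead.
images = pipe(
    prompt=["a watercolor painting of a lighthouse"],  # illustrative prompt
    cfg=6.0,
    last_scale_temp=None,
)
```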
models/switti.py CHANGED
@@ -9,7 +9,7 @@ from diffusers.models.embeddings import GaussianFourierProjection
 
 from models.basic_switti import AdaLNBeforeHead, AdaLNSelfCrossAttn
 from models.rope import compute_axial_cis
-
+from models.helpers import RESOLUTION_PATCH_NUMS_MAPPING
 
 def get_crop_condition(
     heights: list,
@@ -53,7 +53,6 @@ class Switti(nn.Module):
         use_swiglu_ffn=True,
         use_ar=False,
         use_crop_cond=True,
-        device='cuda',
     ):
         super().__init__()
         # 0. hyperparameters
@@ -392,20 +391,20 @@ class SwittiHF(Switti, PyTorchModelHubMixin):
         use_swiglu_ffn=True,
         use_ar=False,
         use_crop_cond=True,
-        device='cuda',
+        reso=512,
     ):
         heads = depth
         width = depth * 64
+        patch_nums = tuple([int(x) for x in RESOLUTION_PATCH_NUMS_MAPPING[reso].split("_")])
         super().__init__(
             depth=depth,
             embed_dim=width,
             num_heads=heads,
-            patch_nums=(1, 2, 3, 4, 6, 9, 13, 18, 24, 32),
+            patch_nums=patch_nums,
             rope=rope,
             rope_theta=rope_theta,
             rope_size=rope_size,
             use_swiglu_ffn=use_swiglu_ffn,
             use_ar=use_ar,
             use_crop_cond=use_crop_cond,
-            device=device,
         )
models/vqvae.py CHANGED
@@ -13,7 +13,7 @@ from huggingface_hub import PyTorchModelHubMixin
 
 from .basic_vae import Decoder, Encoder
 from .quant import VectorQuantizer2
-
+from models.helpers import RESOLUTION_PATCH_NUMS_MAPPING
 
 
 class VQVAE(nn.Module):
@@ -172,8 +172,9 @@ class VQVAEHF(VQVAE, PyTorchModelHubMixin):
         ch=160,
         test_mode=True,
         share_quant_resi=4,
-        v_patch_nums=(1, 2, 3, 4, 6, 9, 13, 18, 24, 32),
+        reso=1024,
     ):
+        v_patch_nums = tuple((int(x) for x in RESOLUTION_PATCH_NUMS_MAPPING[reso].split("_")))
         super().__init__(
             vocab_size=vocab_size,
             z_channels=z_channels,
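With `v_patch_nums` derived from `reso`, the VAE and the transformer agree on the number of scales whenever the same resolution is requested; `SwittiPipeline.from_pretrained` forwards its `reso` argument to the VAE for exactly this reason. A small sketch of loading the VAE at either resolution, assuming PyTorchModelHubMixin forwards the extra `reso` kwarg to `__init__` the way the pipeline's own call relies on:

```python
# Sketch: selecting the VAE's multi-scale schedule by resolution.
from models.vqvae import VQVAEHF

# reso=512 -> v_patch_nums (1, 2, 3, 4, 6, 9, 13, 18, 24, 32), 10 scales
vae_512 = VQVAEHF.from_pretrained("yresearch/VQVAE-Switti", reso=512)

# reso=1024 -> 14 scales ending in 64, matching Switti-1024's patch_nums
vae_1024 = VQVAEHF.from_pretrained("yresearch/VQVAE-Switti", reso=1024)
```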