Spaces: Running on Zero

realantonvoronov committed · Commit 1dc27f0 · 1 Parent(s): 484ca0e

update for 1024

Browse files:
- app.py +8 -8
- models/helpers.py +7 -0
- models/pipeline.py +21 -15
- models/switti.py +4 -5
- models/vqvae.py +3 -2
app.py
CHANGED
@@ -1,16 +1,16 @@
-import gradio as gr
-import numpy as np
 import random
 
+import gradio as gr
+import numpy as np
 import spaces
-from models import SwittiPipeline
 import torch
 
-
-model_repo_id = "yresearch/Switti"
+from models import SwittiPipeline
 
 
-
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model_repo_id = "yresearch/Switti-1024"
+pipe = SwittiPipeline.from_pretrained(model_repo_id, device=device, torch_dtype=torch.bfloat16)
 
 MAX_SEED = np.iinfo(np.int32).max
 
@@ -140,9 +140,9 @@ with gr.Blocks(css=css) as demo:
                 turn_off_cfg_start_si = gr.Slider(
                     label="Disable CFG starting scale",
                     minimum=0,
-                    maximum=
+                    maximum=14,
                     step=1,
-                    value=
+                    value=11,
                 )
             with gr.Row():
                 more_diverse = gr.Checkbox(label="More diverse", value=False)
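The updated app.py now builds the pipeline once at module import time. Below is a minimal sketch of how this setup could be exercised outside the Space: the loading lines mirror the additions above, while the generation call is purely illustrative, since the Space's inference function is not part of this diff (exposing the sampler as pipe(...) and returning PIL images are assumptions; the keyword arguments are taken from the sampler signature in models/pipeline.py further down).

import random

import torch

from models import SwittiPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = SwittiPipeline.from_pretrained("yresearch/Switti-1024",
                                      device=device,
                                      torch_dtype=torch.bfloat16)

seed = random.randint(0, 2**31 - 1)  # stays within MAX_SEED = np.iinfo(np.int32).max

# Hypothetical call: argument names come from the sampler signature in models/pipeline.py;
# whether it is exposed as __call__ is an assumption.
images = pipe(
    "a red panda reading a book",
    seed=seed,
    cfg=6.0,
    top_k=400,
    top_p=0.95,
    turn_off_cfg_start_si=11,
)
images[0].save("sample.png")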
models/helpers.py
CHANGED
@@ -3,6 +3,13 @@ from torch import nn as nn
 from torch.nn import functional as F
 
 
+RESOLUTION_PATCH_NUMS_MAPPING = {
+    256: "1_2_3_4_5_6_8_10_13_16",
+    512: "1_2_3_4_6_9_13_18_24_32",
+    1024: "1_2_3_4_5_7_9_12_16_21_27_36_48_64",
+}
+
+
 def sample_with_top_k_top_p_(
     logits_BlV: torch.Tensor,
     top_k: int = 0,
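The new mapping encodes each resolution's scale schedule as an underscore-separated string; SwittiHF and VQVAEHF (below) parse it into a tuple of patch sizes. A small sketch of that parsing, using a hypothetical helper name, also shows where the slider bounds in app.py come from: the 1024 schedule has 14 scales, matching maximum=14.

from models.helpers import RESOLUTION_PATCH_NUMS_MAPPING

def patch_nums_for(reso: int) -> tuple[int, ...]:
    # Same parsing as in SwittiHF / VQVAEHF: split on "_" and cast to int.
    return tuple(int(x) for x in RESOLUTION_PATCH_NUMS_MAPPING[reso].split("_"))

print(patch_nums_for(1024))       # (1, 2, 3, 4, 5, 7, 9, 12, 16, 21, 27, 36, 48, 64)
print(len(patch_nums_for(1024)))  # 14 scales, hence maximum=14 for the CFG-disable slider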
models/pipeline.py
CHANGED
@@ -8,14 +8,16 @@ from models.switti import SwittiHF, get_crop_condition
 from models.helpers import sample_with_top_k_top_p_, gumbel_softmax_with_rng
 
 
+TRAIN_IMAGE_SIZE = (512, 512)
+
 class SwittiPipeline:
     vae_path = "yresearch/VQVAE-Switti"
     text_encoder_path = "openai/clip-vit-large-patch14"
     text_encoder_2_path = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
 
-    def __init__(self, switti, vae, text_encoder, text_encoder_2,
-
-
+    def __init__(self, switti, vae, text_encoder, text_encoder_2,
+                 device, dtype=torch.float32,
+                 ):
         self.switti = switti.to(dtype)
         self.vae = vae.to(dtype)
         self.text_encoder = text_encoder.to(dtype)
@@ -27,13 +29,18 @@ class SwittiPipeline:
         self.device = device
 
     @classmethod
-    def from_pretrained(cls,
-
-
+    def from_pretrained(cls,
+                        pretrained_model_name_or_path,
+                        torch_dtype=torch.bfloat16,
+                        device="cuda",
+                        reso=1024,
+                        ):
+        switti = SwittiHF.from_pretrained(pretrained_model_name_or_path).to(device)
+        vae = VQVAEHF.from_pretrained(cls.vae_path, reso=reso).to(device)
         text_encoder = FrozenCLIPEmbedder(cls.text_encoder_path, device=device)
         text_encoder_2 = FrozenCLIPEmbedder(cls.text_encoder_2_path, device=device)
 
-        return cls(switti, vae, text_encoder, text_encoder_2, device)
+        return cls(switti, vae, text_encoder, text_encoder_2, device, torch_dtype)
 
     @staticmethod
     def to_image(tensor):
@@ -84,7 +91,7 @@ class SwittiPipeline:
         prompt: str | list[str],
         null_prompt: str = "",
         seed: int | None = None,
-        cfg: float =
+        cfg: float = 6.,
         top_k: int = 400,
         top_p: float = 0.95,
         more_smooth: bool = False,
@@ -92,8 +99,7 @@ class SwittiPipeline:
         smooth_start_si: int = 0,
         turn_off_cfg_start_si: int = 10,
         turn_on_cfg_start_si: int = 0,
-
-        last_scale_temp: float = 1.,
+        last_scale_temp: None | float = None,
     ) -> torch.Tensor | list[PILImage]:
         """
         only used for inference, on autoregressive mode
@@ -122,8 +128,8 @@ class SwittiPipeline:
         cond_vector = switti.text_pooler(cond_vector)
 
         if switti.use_crop_cond:
-            crop_coords = get_crop_condition(2 * B * [
-                                             2 * B * [
+            crop_coords = get_crop_condition(2 * B * [TRAIN_IMAGE_SIZE[0]],
+                                             2 * B * [TRAIN_IMAGE_SIZE[1]],
                                              ).to(cond_vector.device)
             crop_embed = switti.crop_embed(crop_coords.view(-1)).reshape(2 * B, switti.D)
             crop_cond = switti.crop_proj(crop_embed)
@@ -169,7 +175,7 @@ class SwittiPipeline:
                 if b.attn.caching and b.attn.cached_k is not None:
                     b.attn.cached_k = b.attn.cached_k[:B]
                     b.attn.cached_v = b.attn.cached_v[:B]
-                if b.cross_attn.caching
+                if b.cross_attn.caching and b.cross_attn.cached_k is not None:
                     b.cross_attn.cached_k = b.cross_attn.cached_k[:B]
                     b.cross_attn.cached_v = b.cross_attn.cached_v[:B]
             else:
@@ -197,7 +203,7 @@ class SwittiPipeline:
                 # default const cfg
                 t = cfg
                 logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]
-
+            elif last_scale_temp is not None:
                 logits_BlV = logits_BlV / last_scale_temp
 
             if apply_smooth and si >= smooth_start_si:
@@ -208,7 +214,7 @@ class SwittiPipeline:
                 )
                 h_BChw = idx_Bl @ vae_quant.embedding.weight.unsqueeze(0)
             else:
-                #
+                # default nucleus sampling
                 idx_Bl = sample_with_top_k_top_p_(
                     logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1,
                 )[:, :, 0]
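For reference, a minimal loading sketch using the defaults introduced here (device="cuda", torch_dtype=torch.bfloat16, reso=1024). The reso argument is forwarded to the VQVAE and should match the checkpoint being loaded; the repo id for a lower-resolution variant is an assumption, since only "yresearch/Switti" appears in the old app.py.

import torch

from models import SwittiPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Defaults added in this commit: torch_dtype=torch.bfloat16, device="cuda", reso=1024.
pipe = SwittiPipeline.from_pretrained(
    "yresearch/Switti-1024",
    device=device,
    torch_dtype=torch.bfloat16,
    reso=1024,  # forwarded to VQVAEHF.from_pretrained(cls.vae_path, reso=reso)
)

# Hypothetical 512 variant (repo id assumed, not shown in this diff):
# pipe_512 = SwittiPipeline.from_pretrained("yresearch/Switti", reso=512)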
models/switti.py
CHANGED
@@ -9,7 +9,7 @@ from diffusers.models.embeddings import GaussianFourierProjection
 
 from models.basic_switti import AdaLNBeforeHead, AdaLNSelfCrossAttn
 from models.rope import compute_axial_cis
-
+from models.helpers import RESOLUTION_PATCH_NUMS_MAPPING
 
 def get_crop_condition(
     heights: list,
@@ -53,7 +53,6 @@ class Switti(nn.Module):
         use_swiglu_ffn=True,
         use_ar=False,
         use_crop_cond=True,
-        device='cuda',
     ):
         super().__init__()
         # 0. hyperparameters
@@ -392,20 +391,20 @@ class SwittiHF(Switti, PyTorchModelHubMixin):
         use_swiglu_ffn=True,
         use_ar=False,
         use_crop_cond=True,
-
+        reso=512,
     ):
         heads = depth
         width = depth * 64
+        patch_nums = tuple([int(x) for x in RESOLUTION_PATCH_NUMS_MAPPING[reso].split("_")])
         super().__init__(
             depth=depth,
             embed_dim=width,
             num_heads=heads,
-            patch_nums=
+            patch_nums=patch_nums,
             rope=rope,
             rope_theta=rope_theta,
             rope_size=rope_size,
             use_swiglu_ffn=use_swiglu_ffn,
             use_ar=use_ar,
             use_crop_cond=use_crop_cond,
-            device=device,
         )
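With device removed from the constructor, placement is now done by the caller (SwittiPipeline.from_pretrained moves the model with .to(device)), and reso selects the patch-num schedule instead of a hard-coded tuple. A construction sketch under those assumptions, with the depth value chosen only for illustration:

from models.switti import SwittiHF
from models.helpers import RESOLUTION_PATCH_NUMS_MAPPING

reso = 512  # new default for SwittiHF
patch_nums = tuple(int(x) for x in RESOLUTION_PATCH_NUMS_MAPPING[reso].split("_"))
assert patch_nums == (1, 2, 3, 4, 6, 9, 13, 18, 24, 32)

# Hypothetical direct construction; depth=16 is an assumed value (heads = depth, width = depth * 64).
import torch
model = SwittiHF(depth=16, reso=reso)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")  # device is applied externally now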
models/vqvae.py
CHANGED
@@ -13,7 +13,7 @@ from huggingface_hub import PyTorchModelHubMixin
 
 from .basic_vae import Decoder, Encoder
 from .quant import VectorQuantizer2
-
+from models.helpers import RESOLUTION_PATCH_NUMS_MAPPING
 
 
 class VQVAE(nn.Module):
@@ -172,8 +172,9 @@ class VQVAEHF(VQVAE, PyTorchModelHubMixin):
         ch=160,
         test_mode=True,
         share_quant_resi=4,
-
+        reso=1024,
     ):
+        v_patch_nums = tuple((int(x) for x in RESOLUTION_PATCH_NUMS_MAPPING[reso].split("_")))
        super().__init__(
             vocab_size=vocab_size,
             z_channels=z_channels,
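The VQVAE derives its multi-scale quantizer schedule from the same mapping, so the tokenizer and the transformer stay consistent for a given resolution. A loading sketch with the new default reso=1024 (standalone use of VQVAEHF outside the pipeline is an assumption):

from models.vqvae import VQVAEHF
from models.helpers import RESOLUTION_PATCH_NUMS_MAPPING

# reso=1024 expands to the 14-scale schedule used by the quantizer.
vae = VQVAEHF.from_pretrained("yresearch/VQVAE-Switti", reso=1024)

expected = tuple(int(x) for x in RESOLUTION_PATCH_NUMS_MAPPING[1024].split("_"))
print(expected)  # (1, 2, 3, 4, 5, 7, 9, 12, 16, 21, 27, 36, 48, 64)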