Spaces:
Running
on
Zero
Running
on
Zero
fix model shape error
Browse files
model.py
CHANGED
@@ -3,7 +3,6 @@ import numpy as np
|
|
3 |
import torch
|
4 |
|
5 |
from torch import Tensor, nn
|
6 |
-
|
7 |
from layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
8 |
MLPEmbedder, SingleStreamBlock,
|
9 |
timestep_embedding)
|
@@ -11,11 +10,6 @@ from layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
|
11 |
import torch.distributed as dist
|
12 |
from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid
|
13 |
|
14 |
-
from accelerate.logging import get_logger
|
15 |
-
logger = get_logger(__name__, log_level="INFO")
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
|
20 |
@dataclass
|
21 |
class FluxParams:
|
@@ -27,7 +21,7 @@ class FluxParams:
|
|
27 |
num_heads: int
|
28 |
depth: int
|
29 |
depth_single_blocks: int
|
30 |
-
axes_dim: list
|
31 |
theta: int
|
32 |
qkv_bias: bool
|
33 |
guidance_embed: bool
|
@@ -162,6 +156,11 @@ class Flux(nn.Module):
|
|
162 |
ip_scale: Tensor = 1.0,
|
163 |
return_intermediate: bool = False,
|
164 |
):
|
|
|
|
|
|
|
|
|
|
|
165 |
|
166 |
if return_intermediate:
|
167 |
intermediate_double = []
|
@@ -271,4 +270,4 @@ class Flux(nn.Module):
|
|
271 |
if return_intermediate:
|
272 |
return img, intermediate_double, intermediate_single
|
273 |
else:
|
274 |
-
return img
|
|
|
3 |
import torch
|
4 |
|
5 |
from torch import Tensor, nn
|
|
|
6 |
from layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
7 |
MLPEmbedder, SingleStreamBlock,
|
8 |
timestep_embedding)
|
|
|
10 |
import torch.distributed as dist
|
11 |
from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid
|
12 |
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
@dataclass
|
15 |
class FluxParams:
|
|
|
21 |
num_heads: int
|
22 |
depth: int
|
23 |
depth_single_blocks: int
|
24 |
+
axes_dim: list
|
25 |
theta: int
|
26 |
qkv_bias: bool
|
27 |
guidance_embed: bool
|
|
|
156 |
ip_scale: Tensor = 1.0,
|
157 |
return_intermediate: bool = False,
|
158 |
):
|
159 |
+
inputs = [img, img_ids, txt, txt_ids, timesteps, y]
|
160 |
+
for i, input in enumerate(inputs):
|
161 |
+
if input.shape[0] != 1:
|
162 |
+
inputs[i] = input.unsqueeze(0)
|
163 |
+
img, img_ids, txt, txt_ids, timestpes, y = inputs
|
164 |
|
165 |
if return_intermediate:
|
166 |
intermediate_double = []
|
|
|
270 |
if return_intermediate:
|
271 |
return img, intermediate_double, intermediate_single
|
272 |
else:
|
273 |
+
return img.squeeze()
|