daoyuan98 commited on
Commit
2f7ddaa
·
verified ·
1 Parent(s): 71aa4e1

fix model shape error

Browse files
Files changed (1) hide show
  1. model.py +7 -8
model.py CHANGED
@@ -3,7 +3,6 @@ import numpy as np
3
  import torch
4
 
5
  from torch import Tensor, nn
6
-
7
  from layers import (DoubleStreamBlock, EmbedND, LastLayer,
8
  MLPEmbedder, SingleStreamBlock,
9
  timestep_embedding)
@@ -11,11 +10,6 @@ from layers import (DoubleStreamBlock, EmbedND, LastLayer,
11
  import torch.distributed as dist
12
  from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid
13
 
14
- from accelerate.logging import get_logger
15
- logger = get_logger(__name__, log_level="INFO")
16
-
17
-
18
-
19
 
20
  @dataclass
21
  class FluxParams:
@@ -27,7 +21,7 @@ class FluxParams:
27
  num_heads: int
28
  depth: int
29
  depth_single_blocks: int
30
- axes_dim: list[int]
31
  theta: int
32
  qkv_bias: bool
33
  guidance_embed: bool
@@ -162,6 +156,11 @@ class Flux(nn.Module):
162
  ip_scale: Tensor = 1.0,
163
  return_intermediate: bool = False,
164
  ):
 
 
 
 
 
165
 
166
  if return_intermediate:
167
  intermediate_double = []
@@ -271,4 +270,4 @@ class Flux(nn.Module):
271
  if return_intermediate:
272
  return img, intermediate_double, intermediate_single
273
  else:
274
- return img
 
3
  import torch
4
 
5
  from torch import Tensor, nn
 
6
  from layers import (DoubleStreamBlock, EmbedND, LastLayer,
7
  MLPEmbedder, SingleStreamBlock,
8
  timestep_embedding)
 
10
  import torch.distributed as dist
11
  from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid
12
 
 
 
 
 
 
13
 
14
  @dataclass
15
  class FluxParams:
 
21
  num_heads: int
22
  depth: int
23
  depth_single_blocks: int
24
+ axes_dim: list
25
  theta: int
26
  qkv_bias: bool
27
  guidance_embed: bool
 
156
  ip_scale: Tensor = 1.0,
157
  return_intermediate: bool = False,
158
  ):
159
+ inputs = [img, img_ids, txt, txt_ids, timesteps, y]
160
+ for i, input in enumerate(inputs):
161
+ if input.shape[0] != 1:
162
+ inputs[i] = input.unsqueeze(0)
163
+ img, img_ids, txt, txt_ids, timestpes, y = inputs
164
 
165
  if return_intermediate:
166
  intermediate_double = []
 
270
  if return_intermediate:
271
  return img, intermediate_double, intermediate_single
272
  else:
273
+ return img.squeeze()