Julien Blanchon committed on
Commit e803333 · 1 Parent(s): 6fb693e
Files changed (2):
  1. app.py +22 -37
  2. tim/models/utils/text_encoders.py +1 -1
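In short, this commit moves text-encoder loading and null-caption (CFG) encoding out of app startup and into generate_image, resolves the compute device at call time via torch.cuda.is_available(), and switches the Gemma3 text encoder's attention backend from SDPA to FlashAttention-2.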
app.py CHANGED
@@ -22,25 +22,13 @@ MAX_IMAGE_SIZE = 2048
 # Global variables to store loaded components
 model = None
 scheduler = None
-text_encoder = None
-tokenizer = None
 decode_func = None
-null_cap_feat = None
-null_cap_mask = None
 config = None
 
 
 def load_model_components(device: str = "cuda"):
     """Load all model components once at startup"""
-    global \
-        model, \
-        scheduler, \
-        text_encoder, \
-        tokenizer, \
-        decode_func, \
-        null_cap_feat, \
-        null_cap_mask, \
-        config
+    global model, scheduler, decode_func, config
 
     try:
         # Load configuration
@@ -74,26 +62,6 @@ def load_model_components(device: str = "cuda"):
         else:
             raise ValueError("Unsupported VAE type")
 
-        print("Loading text encoder...")
-        # Load text encoder
-        text_encoder, tokenizer = load_text_encoder(
-            text_encoder_dir=model_config.text_encoder_dir,
-            device=device,
-            weight_dtype=torch.bfloat16,
-        )
-
-        print("Encoding null caption...")
-        # Get null caption features
-        null_cap_feat, null_cap_mask = encode_prompt(
-            tokenizer,
-            text_encoder,
-            device,
-            torch.bfloat16,
-            [""],
-            model_config.use_last_hidden_state,
-            max_seq_length=model_config.max_seq_length,
-        )
-
         print("Loading main model...")
         # Load main model
         model = instantiate_from_config(model_config.network).to(
@@ -129,6 +97,8 @@ def generate_image(
 ):
     """Generate image from text prompt"""
     try:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(f"Using device: {device}")
         # Validate inputs
         if not prompt or len(prompt.strip()) == 0:
             raise ValueError("Please enter a valid prompt")
@@ -136,9 +106,6 @@ def generate_image(
         if model is None or scheduler is None:
             raise RuntimeError("Model components not loaded. Please check the setup.")
 
-        if device == "cuda":
-            model.set_attn_implementation("flash_attention_2")
-
         # Validate dimensions
         if (
             width < 256
@@ -173,7 +140,14 @@ def generate_image(
             generator=generator,
         )
 
-        progress(0.1, desc="Encoding prompt...")
+        progress(0.1, desc="Loading text encoder...")
+
+        # Load text encoder
+        text_encoder, tokenizer = load_text_encoder(
+            text_encoder_dir=config.model.text_encoder_dir,
+            device=device,
+            weight_dtype=dtype,
+        )
 
         # Encode prompt
         cap_features, cap_mask = encode_prompt(
@@ -186,6 +160,17 @@ def generate_image(
            max_seq_length=config.model.max_seq_length,
        )
 
+        # Encode null caption for CFG
+        null_cap_feat, null_cap_mask = encode_prompt(
+            tokenizer,
+            text_encoder,
+            device,
+            dtype,
+            [""],
+            config.model.use_last_hidden_state,
+            max_seq_length=config.model.max_seq_length,
+        )
+
         cur_max_seq_len = cap_mask.sum(dim=-1).max()
         y = cap_features[:, :cur_max_seq_len]
 
 
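Net effect in app.py: the text encoder, tokenizer, and null-caption features are no longer globals built at startup; generate_image now resolves the device, loads the text encoder, and encodes both the prompt and the empty caption on each call. A minimal sketch of that per-call flow, assuming load_text_encoder from tim/models/utils/text_encoders.py and an encode_prompt helper in scope as in app.py; the config and dtype arguments are stand-ins for the app's values, and encode_prompt_lazily is a hypothetical name:

import torch

from tim.models.utils.text_encoders import load_text_encoder


def encode_prompt_lazily(prompt, config, dtype=torch.bfloat16):
    """Sketch of the per-call encoding flow after this commit (hypothetical helper)."""
    # Device is resolved per call now, not fixed at startup.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # The text encoder is loaded on demand inside the generation path.
    text_encoder, tokenizer = load_text_encoder(
        text_encoder_dir=config.model.text_encoder_dir,
        device=device,
        weight_dtype=dtype,
    )

    # Encode the user prompt (encode_prompt assumed imported as in app.py).
    cap_features, cap_mask = encode_prompt(
        tokenizer, text_encoder, device, dtype, [prompt],
        config.model.use_last_hidden_state,
        max_seq_length=config.model.max_seq_length,
    )

    # Encode the empty caption for classifier-free guidance; before this
    # commit these features were computed once at startup and cached in globals.
    null_cap_feat, null_cap_mask = encode_prompt(
        tokenizer, text_encoder, device, dtype, [""],
        config.model.use_last_hidden_state,
        max_seq_length=config.model.max_seq_length,
    )
    return cap_features, cap_mask, null_cap_feat, null_cap_mask

The trade-off: startup is lighter and the device choice adapts per request, at the cost of reloading the text encoder on every call unless it is cached elsewhere.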
 
tim/models/utils/text_encoders.py CHANGED
@@ -11,7 +11,7 @@ def load_text_encoder(text_encoder_dir, device, weight_dtype):
     tokenizer.padding_side = "right"
     text_encoder = Gemma3ForCausalLM.from_pretrained(
         text_encoder_dir,
-        attn_implementation="sdpa",
+        attn_implementation="flash_attention_2",
         device_map="cpu",
         dtype=weight_dtype,
     ).model
 
11
  tokenizer.padding_side = "right"
12
  text_encoder = Gemma3ForCausalLM.from_pretrained(
13
  text_encoder_dir,
14
+ attn_implementation="flash_attention_2",
15
  device_map="cpu",
16
  dtype=weight_dtype,
17
  ).model
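Hard-coding attn_implementation="flash_attention_2" makes the flash-attn package (and, typically, a CUDA GPU at forward time) a prerequisite for the text encoder, even though the weights are first materialized with device_map="cpu". If a softer dependency were desired, one alternative, not what this commit does, is to fall back to SDPA when flash-attn is not importable:

import importlib.util


def pick_attn_implementation() -> str:
    # Prefer FlashAttention-2 when the flash-attn package is installed;
    # otherwise fall back to PyTorch scaled-dot-product attention ("sdpa").
    if importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"

The returned string would be passed straight through as the attn_implementation argument to Gemma3ForCausalLM.from_pretrained; both values are standard transformers options.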