Julien Blanchon committed on
Commit
98bf3ed
·
1 Parent(s): 8f9fc5b
Files changed (2) hide show
  1. app.py +14 -8
  2. tim/models/utils/text_encoders.py +3 -3
app.py CHANGED
@@ -24,11 +24,13 @@ model = None
24
  scheduler = None
25
  decode_func = None
26
  config = None
 
 
27
 
28
 
29
  def load_model_components(device: str = "cuda"):
30
  """Load all model components once at startup"""
31
- global model, scheduler, decode_func, config
32
 
33
  try:
34
  # Load configuration
@@ -62,6 +64,13 @@ def load_model_components(device: str = "cuda"):
62
  else:
63
  raise ValueError("Unsupported VAE type")
64
 
 
 
 
 
 
 
 
65
  print("Loading main model...")
66
  # Load main model
67
  model = instantiate_from_config(model_config.network).to(
@@ -143,16 +152,13 @@ def generate_image(
143
  progress(0.1, desc="Loading text encoder...")
144
 
145
  # Load text encoder
146
- text_encoder, tokenizer = load_text_encoder(
147
- text_encoder_dir=config.model.text_encoder_dir,
148
- device=device,
149
- weight_dtype=dtype,
150
- )
151
 
152
  # Encode prompt
153
  cap_features, cap_mask = encode_prompt(
154
  tokenizer,
155
- text_encoder,
156
  device,
157
  dtype,
158
  [prompt],
@@ -163,7 +169,7 @@ def generate_image(
163
  # Encode null caption for CFG
164
  null_cap_feat, null_cap_mask = encode_prompt(
165
  tokenizer,
166
- text_encoder,
167
  device,
168
  dtype,
169
  [""],
 
24
  scheduler = None
25
  decode_func = None
26
  config = None
27
+ text_encoder = None
28
+ tokenizer = None
29
 
30
 
31
  def load_model_components(device: str = "cuda"):
32
  """Load all model components once at startup"""
33
+ global model, scheduler, decode_func, config, text_encoder, tokenizer
34
 
35
  try:
36
  # Load configuration
 
64
  else:
65
  raise ValueError("Unsupported VAE type")
66
 
67
+ # Load text encoder
68
+ text_encoder, tokenizer = load_text_encoder(
69
+ text_encoder_dir=config.model.text_encoder_dir,
70
+ device=device,
71
+ weight_dtype=dtype,
72
+ )
73
+
74
  print("Loading main model...")
75
  # Load main model
76
  model = instantiate_from_config(model_config.network).to(
 
152
  progress(0.1, desc="Loading text encoder...")
153
 
154
  # Load text encoder
155
+ text_encoder.to(device)
156
+ text_encoder.set_attn_implementation("flash_attention_2")
 
 
 
157
 
158
  # Encode prompt
159
  cap_features, cap_mask = encode_prompt(
160
  tokenizer,
161
+ text_encoder.model,
162
  device,
163
  dtype,
164
  [prompt],
 
169
  # Encode null caption for CFG
170
  null_cap_feat, null_cap_mask = encode_prompt(
171
  tokenizer,
172
+ text_encoder.model,
173
  device,
174
  dtype,
175
  [""],
tim/models/utils/text_encoders.py CHANGED
@@ -11,10 +11,10 @@ def load_text_encoder(text_encoder_dir, device, weight_dtype):
11
  tokenizer.padding_side = "right"
12
  text_encoder = Gemma3ForCausalLM.from_pretrained(
13
  text_encoder_dir,
14
- attn_implementation="flash_attention_2",
15
  device_map="cpu",
16
  dtype=weight_dtype,
17
- ).model
18
  elif "t5" in text_encoder_dir:
19
  text_encoder = T5EncoderModel.from_pretrained(
20
  text_encoder_dir,
@@ -28,7 +28,7 @@ def load_text_encoder(text_encoder_dir, device, weight_dtype):
28
  # for param in text_encoder.parameters():
29
  # param.requires_grad = False
30
 
31
- text_encoder = text_encoder.eval().to(device=device, dtype=weight_dtype)
32
 
33
  return text_encoder, tokenizer
34
 
 
11
  tokenizer.padding_side = "right"
12
  text_encoder = Gemma3ForCausalLM.from_pretrained(
13
  text_encoder_dir,
14
+ attn_implementation="sdpa",
15
  device_map="cpu",
16
  dtype=weight_dtype,
17
+ )
18
  elif "t5" in text_encoder_dir:
19
  text_encoder = T5EncoderModel.from_pretrained(
20
  text_encoder_dir,
 
28
  # for param in text_encoder.parameters():
29
  # param.requires_grad = False
30
 
31
+ text_encoder.model = text_encoder.model.eval().to(device=device, dtype=weight_dtype)
32
 
33
  return text_encoder, tokenizer
34