Julien Blanchon
committed on
Commit
·
98bf3ed
1
Parent(s):
8f9fc5b
soijds
Browse files
- app.py +14 -8
- tim/models/utils/text_encoders.py +3 -3
app.py
CHANGED
@@ -24,11 +24,13 @@ model = None
|
|
24 |
scheduler = None
|
25 |
decode_func = None
|
26 |
config = None
|
|
|
|
|
27 |
|
28 |
|
29 |
def load_model_components(device: str = "cuda"):
|
30 |
"""Load all model components once at startup"""
|
31 |
-
global model, scheduler, decode_func, config
|
32 |
|
33 |
try:
|
34 |
# Load configuration
|
@@ -62,6 +64,13 @@ def load_model_components(device: str = "cuda"):
|
|
62 |
else:
|
63 |
raise ValueError("Unsupported VAE type")
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
print("Loading main model...")
|
66 |
# Load main model
|
67 |
model = instantiate_from_config(model_config.network).to(
|
@@ -143,16 +152,13 @@ def generate_image(
|
|
143 |
progress(0.1, desc="Loading text encoder...")
|
144 |
|
145 |
# Load text encoder
|
146 |
-
text_encoder
|
147 |
-
|
148 |
-
device=device,
|
149 |
-
weight_dtype=dtype,
|
150 |
-
)
|
151 |
|
152 |
# Encode prompt
|
153 |
cap_features, cap_mask = encode_prompt(
|
154 |
tokenizer,
|
155 |
-
text_encoder,
|
156 |
device,
|
157 |
dtype,
|
158 |
[prompt],
|
@@ -163,7 +169,7 @@ def generate_image(
|
|
163 |
# Encode null caption for CFG
|
164 |
null_cap_feat, null_cap_mask = encode_prompt(
|
165 |
tokenizer,
|
166 |
-
text_encoder,
|
167 |
device,
|
168 |
dtype,
|
169 |
[""],
|
|
|
24 |
scheduler = None
|
25 |
decode_func = None
|
26 |
config = None
|
27 |
+
text_encoder = None
|
28 |
+
tokenizer = None
|
29 |
|
30 |
|
31 |
def load_model_components(device: str = "cuda"):
|
32 |
"""Load all model components once at startup"""
|
33 |
+
global model, scheduler, decode_func, config, text_encoder, tokenizer
|
34 |
|
35 |
try:
|
36 |
# Load configuration
|
|
|
64 |
else:
|
65 |
raise ValueError("Unsupported VAE type")
|
66 |
|
67 |
+
# Load text encoder
|
68 |
+
text_encoder, tokenizer = load_text_encoder(
|
69 |
+
text_encoder_dir=config.model.text_encoder_dir,
|
70 |
+
device=device,
|
71 |
+
weight_dtype=dtype,
|
72 |
+
)
|
73 |
+
|
74 |
print("Loading main model...")
|
75 |
# Load main model
|
76 |
model = instantiate_from_config(model_config.network).to(
|
|
|
152 |
progress(0.1, desc="Loading text encoder...")
|
153 |
|
154 |
# Load text encoder
|
155 |
+
text_encoder.to(device)
|
156 |
+
text_encoder.set_attn_implementation("flash_attention_2")
|
|
|
|
|
|
|
157 |
|
158 |
# Encode prompt
|
159 |
cap_features, cap_mask = encode_prompt(
|
160 |
tokenizer,
|
161 |
+
text_encoder.model,
|
162 |
device,
|
163 |
dtype,
|
164 |
[prompt],
|
|
|
169 |
# Encode null caption for CFG
|
170 |
null_cap_feat, null_cap_mask = encode_prompt(
|
171 |
tokenizer,
|
172 |
+
text_encoder.model,
|
173 |
device,
|
174 |
dtype,
|
175 |
[""],
|
tim/models/utils/text_encoders.py
CHANGED
@@ -11,10 +11,10 @@ def load_text_encoder(text_encoder_dir, device, weight_dtype):
|
|
11 |
tokenizer.padding_side = "right"
|
12 |
text_encoder = Gemma3ForCausalLM.from_pretrained(
|
13 |
text_encoder_dir,
|
14 |
-
attn_implementation="
|
15 |
device_map="cpu",
|
16 |
dtype=weight_dtype,
|
17 |
-
)
|
18 |
elif "t5" in text_encoder_dir:
|
19 |
text_encoder = T5EncoderModel.from_pretrained(
|
20 |
text_encoder_dir,
|
@@ -28,7 +28,7 @@ def load_text_encoder(text_encoder_dir, device, weight_dtype):
|
|
28 |
# for param in text_encoder.parameters():
|
29 |
# param.requires_grad = False
|
30 |
|
31 |
-
text_encoder = text_encoder.eval().to(device=device, dtype=weight_dtype)
|
32 |
|
33 |
return text_encoder, tokenizer
|
34 |
|
|
|
11 |
tokenizer.padding_side = "right"
|
12 |
text_encoder = Gemma3ForCausalLM.from_pretrained(
|
13 |
text_encoder_dir,
|
14 |
+
attn_implementation="sdpa",
|
15 |
device_map="cpu",
|
16 |
dtype=weight_dtype,
|
17 |
+
)
|
18 |
elif "t5" in text_encoder_dir:
|
19 |
text_encoder = T5EncoderModel.from_pretrained(
|
20 |
text_encoder_dir,
|
|
|
28 |
# for param in text_encoder.parameters():
|
29 |
# param.requires_grad = False
|
30 |
|
31 |
+
text_encoder.model = text_encoder.model.eval().to(device=device, dtype=weight_dtype)
|
32 |
|
33 |
return text_encoder, tokenizer
|
34 |
|