Julien Blanchon committed
Commit · e803333
1 Parent(s): 6fb693e
Updatr

Files changed:
- app.py (+22 -37)
- tim/models/utils/text_encoders.py (+1 -1)
app.py
CHANGED
@@ -22,25 +22,13 @@ MAX_IMAGE_SIZE = 2048
 # Global variables to store loaded components
 model = None
 scheduler = None
-text_encoder = None
-tokenizer = None
 decode_func = None
-null_cap_feat = None
-null_cap_mask = None
 config = None
 
 
 def load_model_components(device: str = "cuda"):
     """Load all model components once at startup"""
-    global \
-        model, \
-        scheduler, \
-        text_encoder, \
-        tokenizer, \
-        decode_func, \
-        null_cap_feat, \
-        null_cap_mask, \
-        config
+    global model, scheduler, decode_func, config
 
     try:
         # Load configuration
@@ -74,26 +62,6 @@ def load_model_components(device: str = "cuda"):
         else:
             raise ValueError("Unsupported VAE type")
 
-        print("Loading text encoder...")
-        # Load text encoder
-        text_encoder, tokenizer = load_text_encoder(
-            text_encoder_dir=model_config.text_encoder_dir,
-            device=device,
-            weight_dtype=torch.bfloat16,
-        )
-
-        print("Encoding null caption...")
-        # Get null caption features
-        null_cap_feat, null_cap_mask = encode_prompt(
-            tokenizer,
-            text_encoder,
-            device,
-            torch.bfloat16,
-            [""],
-            model_config.use_last_hidden_state,
-            max_seq_length=model_config.max_seq_length,
-        )
-
         print("Loading main model...")
         # Load main model
         model = instantiate_from_config(model_config.network).to(
@@ -129,6 +97,8 @@ def generate_image(
 ):
     """Generate image from text prompt"""
     try:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(f"Using device: {device}")
         # Validate inputs
         if not prompt or len(prompt.strip()) == 0:
             raise ValueError("Please enter a valid prompt")
@@ -136,9 +106,6 @@ def generate_image(
         if model is None or scheduler is None:
            raise RuntimeError("Model components not loaded. Please check the setup.")
 
-        if device == "cuda":
-            model.set_attn_implementation("flash_attention_2")
-
         # Validate dimensions
         if (
             width < 256
@@ -173,7 +140,14 @@ def generate_image(
             generator=generator,
         )
 
-        progress(0.1, desc="
+        progress(0.1, desc="Loading text encoder...")
+
+        # Load text encoder
+        text_encoder, tokenizer = load_text_encoder(
+            text_encoder_dir=config.model.text_encoder_dir,
+            device=device,
+            weight_dtype=dtype,
+        )
 
         # Encode prompt
         cap_features, cap_mask = encode_prompt(
@@ -186,6 +160,17 @@ def generate_image(
             max_seq_length=config.model.max_seq_length,
         )
 
+        # Encode null caption for CFG
+        null_cap_feat, null_cap_mask = encode_prompt(
+            tokenizer,
+            text_encoder,
+            device,
+            dtype,
+            [""],
+            config.model.use_last_hidden_state,
+            max_seq_length=config.model.max_seq_length,
+        )
+
         cur_max_seq_len = cap_mask.sum(dim=-1).max()
         y = cap_features[:, :cur_max_seq_len]
 
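Note on the app.py diff above: the Gemma text encoder, tokenizer, and precomputed null-caption features are removed from the startup path. load_model_components now manages only model, scheduler, decode_func, and config, while generate_image picks the device at call time via torch.cuda.is_available() and loads the text encoder per request. As written, load_text_encoder runs on every generation call; if that proves slow, a module-level cache is the usual fix. A minimal sketch, reusing the repo's load_text_encoder but with the helper name and cache variables invented here for illustration:

import torch

from tim.models.utils.text_encoders import load_text_encoder

_text_encoder = None  # cached across requests (illustrative, not in the repo)
_tokenizer = None


def get_text_encoder(text_encoder_dir: str, device: str, dtype: torch.dtype):
    """Load the text encoder on first use, then reuse the cached instance."""
    global _text_encoder, _tokenizer
    if _text_encoder is None:
        _text_encoder, _tokenizer = load_text_encoder(
            text_encoder_dir=text_encoder_dir,
            device=device,
            weight_dtype=dtype,
        )
    return _text_encoder, _tokenizer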
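The null-caption encoding that moves into generate_image serves classifier-free guidance (per the new "# Encode null caption for CFG" comment): the diffusion model is evaluated with both the real prompt features (cap_features) and the empty-caption features (null_cap_feat), and the two predictions are blended at each sampling step. How the blend is applied here depends on the scheduler, but the standard formula looks like the sketch below, with tensor names and the guidance-scale default invented for illustration:

import torch


def cfg_blend(
    cond_pred: torch.Tensor,
    uncond_pred: torch.Tensor,
    guidance_scale: float = 5.0,
) -> torch.Tensor:
    """Classifier-free guidance: extrapolate from the unconditional
    (null-caption) prediction toward the conditional one."""
    return uncond_pred + guidance_scale * (cond_pred - uncond_pred)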
tim/models/utils/text_encoders.py
CHANGED
@@ -11,7 +11,7 @@ def load_text_encoder(text_encoder_dir, device, weight_dtype):
     tokenizer.padding_side = "right"
     text_encoder = Gemma3ForCausalLM.from_pretrained(
         text_encoder_dir,
-        attn_implementation="
+        attn_implementation="flash_attention_2",
         device_map="cpu",
         dtype=weight_dtype,
     ).model
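One caveat on the text_encoders.py change: attn_implementation="flash_attention_2" requires the flash-attn package and a CUDA GPU, while generate_image now falls back to CPU when CUDA is unavailable (and this loader first materializes the model with device_map="cpu"). If the Space can land on a CPU machine, a guarded choice avoids a hard failure; a sketch, assuming nothing beyond flash-attn being an importable package:

import importlib.util

import torch


def pick_attn_implementation() -> str:
    """Prefer FlashAttention-2 when it can actually run, else fall back
    to PyTorch's built-in scaled-dot-product attention ("sdpa")."""
    if torch.cuda.is_available() and importlib.util.find_spec("flash_attn"):
        return "flash_attention_2"
    return "sdpa"

Passing attn_implementation=pick_attn_implementation() to Gemma3ForCausalLM.from_pretrained would keep the fast path on GPU machines without breaking CPU-only ones.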