Update app.py
Browse files
app.py
CHANGED
@@ -261,90 +261,115 @@ image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", m
|
|
261 |
image_adapter.eval()
|
262 |
image_adapter.to("cuda")
|
263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
264 |
|
265 |
@spaces.GPU()
|
266 |
@torch.no_grad()
|
267 |
def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int, art_style: str) -> str:
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
|
|
|
|
|
|
|
|
348 |
|
349 |
css = """
|
350 |
h1, h2, h3, h4, h5, h6, p, li, ul, ol, a, .centered-image {
|
|
|
261 |
image_adapter.eval()
|
262 |
image_adapter.to("cuda")
|
263 |
|
264 |
+
# After loading the tokenizer and model: log the tokenizer's special-token
# configuration and the text model's device once at start-up.
print(f"Tokenizer class: {type(tokenizer)}")
print(f"BOS token: {tokenizer.bos_token}")
print(f"BOS token ID: {tokenizer.bos_token_id}")
print(f"EOS token: {tokenizer.eos_token}")
print(f"EOS token ID: {tokenizer.eos_token_id}")
print(f"Text model device: {text_model.device}")

# Ensure the tokenizer has the necessary special tokens. Any of BOS/EOS that
# has no id is mapped to the '<|endoftext|>' fallback.
special_tokens_dict = {}
if tokenizer.bos_token_id is None:
    special_tokens_dict['bos_token'] = '<|endoftext|>'
if tokenizer.eos_token_id is None:
    special_tokens_dict['eos_token'] = '<|endoftext|>'

if special_tokens_dict:
    print("Warning: BOS or EOS token is missing. Adding default tokens.")
    num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    print(f"Added {num_added_tokens} special tokens to the tokenizer.")

    # Resize the model's token embeddings only when the vocabulary actually
    # grew; if '<|endoftext|>' was already in the vocabulary nothing was added
    # and the embedding matrix is left untouched.
    # NOTE(review): a newly added token presumably gets a freshly initialised
    # embedding row - confirm this fallback is acceptable for the checkpoint.
    if num_added_tokens > 0:
        text_model.resize_token_embeddings(len(tokenizer))
|
285 |
|
286 |
@spaces.GPU()
@torch.no_grad()
def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int, art_style: str) -> str:
    """Generate a caption for ``input_image`` with the module-level models.

    The prompt template is looked up in ``CAPTION_TYPE_MAP`` using the key
    ``(caption_type, caption_tone, length_is_str, length_is_int)`` and is
    formatted with the requested length and art style. The image is embedded
    through ``clip_model`` and ``image_adapter``, concatenated with the BOS
    and prompt embeddings, and passed to ``text_model.generate``.

    Args:
        input_image: Image to caption; it is resized to 384x384.
        caption_type: Caption type, first element of the prompt key.
        caption_tone: Caption tone; forced to ``"formal"`` for the
            ``"rng-tags"`` and ``"training_prompt"`` caption types.
        caption_length: ``"any"`` for no length constraint, a numeric string
            or int for a word count, or any other string, which is passed
            through to the prompt template unchanged.
        art_style: Style name inserted into the prompt and used to look up
            ``STYLE_CHARACTERISTICS`` and ``STYLE_FOCUS`` (each with a generic
            fallback phrase).

    Returns:
        The decoded caption with surrounding whitespace stripped.

    Raises:
        ValueError: If no image is supplied, or if the prompt key is not
            present in ``CAPTION_TYPE_MAP``.
    """
    # Fail with a clear message instead of an AttributeError on `.resize`.
    if input_image is None:
        raise ValueError("No input image provided.")

    torch.cuda.empty_cache()

    # 'any' means no length specified
    length = None if caption_length == "any" else caption_length

    if isinstance(length, str):
        try:
            length = int(length)
        except ValueError:
            # Deliberate: non-numeric strings stay strings and select the
            # string-length variant of the prompt key below.
            pass

    # 'rng-tags' and 'training_prompt' don't have formal/informal tones
    if caption_type in ("rng-tags", "training_prompt"):
        caption_tone = "formal"

    # Build prompt
    prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
    if prompt_key not in CAPTION_TYPE_MAP:
        raise ValueError(f"Invalid caption type: {prompt_key}")

    prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(
        length=length,
        word_count=length,
        style=art_style,
        style_characteristics=STYLE_CHARACTERISTICS.get(art_style, "its unique elements"),
        style_focus=STYLE_FOCUS.get(art_style, "its distinctive features"),
    )
    print(f"Prompt: {prompt_str}")

    # Preprocess image
    # NOTE(review): assumes the vision model expects 3-channel RGB input, so
    # palette / greyscale / RGBA images are converted first - confirm.
    image = input_image if input_image.mode == "RGB" else input_image.convert("RGB")
    image = image.resize((384, 384), Image.LANCZOS)
    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
    pixel_values = pixel_values.to('cuda')

    # Tokenize the prompt (no special tokens; BOS and end-of-turn are added
    # explicitly further down).
    prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)

    # Embed image
    with torch.amp.autocast_mode.autocast('cuda', enabled=True):
        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
        image_features = vision_outputs.hidden_states
        embedded_images = image_adapter(image_features)
        embedded_images = embedded_images.to('cuda')

    # Embed prompt
    prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
    assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"

    # Check for bos_token_id and provide a fallback
    bos_token_id = tokenizer.bos_token_id
    if bos_token_id is None:
        print("Warning: bos_token_id is None. Using default value of 1.")
        bos_token_id = 1  # Common default, but may need adjustment

    # Look the end-of-turn id up once; it is used for both the input ids and
    # the trailing-token trim after generation.
    eot_token_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")

    embedded_bos = text_model.model.embed_tokens(torch.tensor([[bos_token_id]], device=text_model.device, dtype=torch.int64))
    eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)

    # Construct prompts: [BOS, image embeddings, prompt, end-of-turn]
    batch_size = embedded_images.shape[0]
    inputs_embeds = torch.cat([
        embedded_bos.expand(batch_size, -1, -1),
        embedded_images.to(dtype=embedded_bos.dtype),
        prompt_embeds.expand(batch_size, -1, -1),
        eot_embed.expand(batch_size, -1, -1),
    ], dim=1)

    # input_ids mirrors the layout of inputs_embeds, with zeros standing in
    # for the image positions.
    input_ids = torch.cat([
        torch.tensor([[bos_token_id]], dtype=torch.long),
        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
        prompt,
        torch.tensor([[eot_token_id]], dtype=torch.long),
    ], dim=1).to('cuda')
    attention_mask = torch.ones_like(input_ids)

    generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None)

    # Trim off the prompt
    generate_ids = generate_ids[:, input_ids.shape[1]:]

    # Drop a trailing EOS / end-of-turn token; the length guard avoids an
    # IndexError if nothing was generated.
    if generate_ids.shape[1] > 0 and generate_ids[0, -1].item() in (tokenizer.eos_token_id, eot_token_id):
        generate_ids = generate_ids[:, :-1]

    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]

    return caption.strip()
|
373 |
|
374 |
css = """
|
375 |
h1, h2, h3, h4, h5, h6, p, li, ul, ol, a, .centered-image {
|