Severian committed
Commit 9bc81e0 · verified · 1 Parent(s): 5bbe627

Update app.py

Files changed (1)
  1. app.py +105 -80
app.py CHANGED
@@ -261,90 +261,115 @@ image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", m
  image_adapter.eval()
  image_adapter.to("cuda")


  @spaces.GPU()
  @torch.no_grad()
  def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int, art_style: str) -> str:
-     torch.cuda.empty_cache()
-
-     # 'any' means no length specified
-     length = None if caption_length == "any" else caption_length
-
-     if isinstance(length, str):
-         try:
-             length = int(length)
-         except ValueError:
-             pass
-
-     # 'rng-tags' and 'training_prompt' don't have formal/informal tones
-     if caption_type == "rng-tags" or caption_type == "training_prompt":
-         caption_tone = "formal"
-
-     # Build prompt
-     prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
-     if prompt_key not in CAPTION_TYPE_MAP:
-         raise ValueError(f"Invalid caption type: {prompt_key}")
-
-     prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(
-         length=length,
-         word_count=length,
-         style=art_style,
-         style_characteristics=STYLE_CHARACTERISTICS.get(art_style, "its unique elements"),
-         style_focus=STYLE_FOCUS.get(art_style, "its distinctive features")
-     )
-     print(f"Prompt: {prompt_str}")
-
-     # Preprocess image
-     #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
-     image = input_image.resize((384, 384), Image.LANCZOS)
-     pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
-     pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
-     pixel_values = pixel_values.to('cuda')
-
-     # Tokenize the prompt
-     prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
-
-     # Embed image
-     with torch.amp.autocast_mode.autocast('cuda', enabled=True):
-         vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
-         image_features = vision_outputs.hidden_states
-         embedded_images = image_adapter(image_features)
-         embedded_images = embedded_images.to('cuda')
-
-     # Embed prompt
-     prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
-     assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
-     embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
-     eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
-
-     # Construct prompts
-     inputs_embeds = torch.cat([
-         embedded_bos.expand(embedded_images.shape[0], -1, -1),
-         embedded_images.to(dtype=embedded_bos.dtype),
-         prompt_embeds.expand(embedded_images.shape[0], -1, -1),
-         eot_embed.expand(embedded_images.shape[0], -1, -1),
-     ], dim=1)
-
-     input_ids = torch.cat([
-         torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
-         torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
-         prompt,
-         torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
-     ], dim=1).to('cuda')
-     attention_mask = torch.ones_like(input_ids)
-
-     #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=False, suppress_tokens=None)
-     #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)
-     generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None) # Uses the default which is temp=0.6, top_p=0.9
-
-     # Trim off the prompt
-     generate_ids = generate_ids[:, input_ids.shape[1]:]
-     if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
-         generate_ids = generate_ids[:, :-1]
-
-     caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
-
-     return caption.strip()

  css = """
  h1, h2, h3, h4, h5, h6, p, li, ul, ol, a, .centered-image {
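
Both versions of stream_chat build input_ids to mirror inputs_embeds position for position, with zeros standing in for the image-embedding slots, so that generate() sees matching sequence lengths. A minimal sketch of that alignment invariant, using hypothetical sizes in place of the real model dimensions:

import torch

# Hypothetical sizes; the real values come from the text model and image adapter.
hidden_size = 4096
num_image_tokens = 729
prompt_len = 12
bos_token_id = 1        # hypothetical BOS id
eot_token_id = 128009   # hypothetical <|eot_id|> id

# inputs_embeds layout: BOS + image embeddings + prompt embeddings + EOT.
inputs_embeds = torch.randn(1, 1 + num_image_tokens + prompt_len + 1, hidden_size)

# input_ids mirrors that layout; zeros are placeholders for the image positions.
input_ids = torch.cat([
    torch.tensor([[bos_token_id]], dtype=torch.long),
    torch.zeros((1, num_image_tokens), dtype=torch.long),
    torch.randint(0, 32000, (1, prompt_len)),   # stand-in prompt token ids
    torch.tensor([[eot_token_id]], dtype=torch.long),
], dim=1)

# The invariant the generate() call relies on: one token id per embedding row.
assert input_ids.shape[1] == inputs_embeds.shape[1]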
 
  image_adapter.eval()
  image_adapter.to("cuda")

+ # After loading the tokenizer and model
+ print(f"Tokenizer class: {type(tokenizer)}")
+ print(f"BOS token: {tokenizer.bos_token}")
+ print(f"BOS token ID: {tokenizer.bos_token_id}")
+ print(f"EOS token: {tokenizer.eos_token}")
+ print(f"EOS token ID: {tokenizer.eos_token_id}")
+ print(f"Text model device: {text_model.device}")
+
+ # Ensure the tokenizer has the necessary special tokens
+ if tokenizer.bos_token_id is None or tokenizer.eos_token_id is None:
+     print("Warning: BOS or EOS token is missing. Adding default tokens.")
+     special_tokens_dict = {}
+     if tokenizer.bos_token_id is None:
+         special_tokens_dict['bos_token'] = '<|endoftext|>'
+     if tokenizer.eos_token_id is None:
+         special_tokens_dict['eos_token'] = '<|endoftext|>'
+     num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+     print(f"Added {num_added_tokens} special tokens to the tokenizer.")
+
+ # Resize token embeddings of the model if new tokens are added
+ text_model.resize_token_embeddings(len(tokenizer))

  @spaces.GPU()
  @torch.no_grad()
  def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int, art_style: str) -> str:
+     torch.cuda.empty_cache()
+
+     # 'any' means no length specified
+     length = None if caption_length == "any" else caption_length
+
+     if isinstance(length, str):
+         try:
+             length = int(length)
+         except ValueError:
+             pass
+
+     # 'rng-tags' and 'training_prompt' don't have formal/informal tones
+     if caption_type == "rng-tags" or caption_type == "training_prompt":
+         caption_tone = "formal"
+
+     # Build prompt
+     prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
+     if prompt_key not in CAPTION_TYPE_MAP:
+         raise ValueError(f"Invalid caption type: {prompt_key}")
+
+     prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(
+         length=length,
+         word_count=length,
+         style=art_style,
+         style_characteristics=STYLE_CHARACTERISTICS.get(art_style, "its unique elements"),
+         style_focus=STYLE_FOCUS.get(art_style, "its distinctive features")
+     )
+     print(f"Prompt: {prompt_str}")
+
+     # Preprocess image
+     image = input_image.resize((384, 384), Image.LANCZOS)
+     pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
+     pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
+     pixel_values = pixel_values.to('cuda')
+
+     # Tokenize the prompt
+     prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
+
+     # Embed image
+     with torch.amp.autocast_mode.autocast('cuda', enabled=True):
+         vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
+         image_features = vision_outputs.hidden_states
+         embedded_images = image_adapter(image_features)
+         embedded_images = embedded_images.to('cuda')
+
+     # Embed prompt
+     prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
+     assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
+
+     # Check for bos_token_id and provide a fallback
+     bos_token_id = tokenizer.bos_token_id
+     if bos_token_id is None:
+         print("Warning: bos_token_id is None. Using default value of 1.")
+         bos_token_id = 1  # Common default, but may need adjustment
+
+     embedded_bos = text_model.model.embed_tokens(torch.tensor([[bos_token_id]], device=text_model.device, dtype=torch.int64))
+     eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
+
+     # Construct prompts
+     inputs_embeds = torch.cat([
+         embedded_bos.expand(embedded_images.shape[0], -1, -1),
+         embedded_images.to(dtype=embedded_bos.dtype),
+         prompt_embeds.expand(embedded_images.shape[0], -1, -1),
+         eot_embed.expand(embedded_images.shape[0], -1, -1),
+     ], dim=1)
+
+     input_ids = torch.cat([
+         torch.tensor([[bos_token_id]], dtype=torch.long),
+         torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
+         prompt,
+         torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
+     ], dim=1).to('cuda')
+     attention_mask = torch.ones_like(input_ids)
+
+     generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None)
+
+     # Trim off the prompt
+     generate_ids = generate_ids[:, input_ids.shape[1]:]
+     if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
+         generate_ids = generate_ids[:, :-1]
+
+     caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
+
+     return caption.strip()

  css = """
  h1, h2, h3, h4, h5, h6, p, li, ul, ol, a, .centered-image {
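
The substantive change in this commit is the BOS/EOS fallback around tokenizer.bos_token_id. Pulled out of app.py into a standalone form, the pattern looks roughly like the sketch below; ensure_special_tokens is a hypothetical helper name, and the generic transformers types are assumptions rather than what app.py imports:

from transformers import PreTrainedModel, PreTrainedTokenizerBase

def ensure_special_tokens(tokenizer: PreTrainedTokenizerBase, model: PreTrainedModel) -> None:
    # Register fallback BOS/EOS tokens only if the checkpoint shipped without them.
    special_tokens = {}
    if tokenizer.bos_token_id is None:
        special_tokens["bos_token"] = "<|endoftext|>"
    if tokenizer.eos_token_id is None:
        special_tokens["eos_token"] = "<|endoftext|>"
    if special_tokens:
        added = tokenizer.add_special_tokens(special_tokens)
        print(f"Added {added} special tokens to the tokenizer.")
        # Grow the embedding table so the new ids map onto valid rows.
        model.resize_token_embeddings(len(tokenizer))

Unlike the committed code, which calls resize_token_embeddings unconditionally, this sketch gates the resize on tokens actually having been added; resizing to an unchanged vocabulary size is harmless but unnecessary.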