Severian committed on
Commit
bd55f23
·
verified ·
1 Parent(s): 799b4bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -242
app.py CHANGED
@@ -261,118 +261,84 @@ image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", m
261
  image_adapter.eval()
262
  image_adapter.to("cuda")
263
 
264
- # After loading the tokenizer and model
265
- print(f"Tokenizer class: {type(tokenizer)}")
266
- print(f"BOS token: {tokenizer.bos_token}")
267
- print(f"BOS token ID: {tokenizer.bos_token_id}")
268
- print(f"EOS token: {tokenizer.eos_token}")
269
- print(f"EOS token ID: {tokenizer.eos_token_id}")
270
- print(f"Text model device: {text_model.device}")
271
-
272
- # Ensure the tokenizer has the necessary special tokens
273
- if tokenizer.bos_token_id is None or tokenizer.eos_token_id is None:
274
- print("Warning: BOS or EOS token is missing. Adding default tokens.")
275
- special_tokens_dict = {}
276
- if tokenizer.bos_token_id is None:
277
- special_tokens_dict['bos_token'] = '<|endoftext|>'
278
- if tokenizer.eos_token_id is None:
279
- special_tokens_dict['eos_token'] = '<|endoftext|>'
280
- num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
281
- print(f"Added {num_added_tokens} special tokens to the tokenizer.")
282
-
283
- # Resize token embeddings of the model if new tokens are added
284
- text_model.resize_token_embeddings(len(tokenizer))
285
 
286
  @spaces.GPU()
287
  @torch.no_grad()
288
- def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int, art_style: str) -> str:
289
- torch.cuda.empty_cache()
290
-
291
- # Handle caption_length
292
- length = None
293
- if caption_length != "any":
294
- if isinstance(caption_length, int):
295
- length = caption_length
296
- elif isinstance(caption_length, str):
297
- try:
298
- length = int(caption_length)
299
- except ValueError:
300
- # If it's not a number, treat it as a descriptive length
301
- length = caption_length
302
-
303
- # 'rng-tags' and 'training_prompt' don't have formal/informal tones
304
- if caption_type in ["rng-tags", "training_prompt"]:
305
- caption_tone = "formal"
306
-
307
- # Build prompt
308
- prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
309
- if prompt_key not in CAPTION_TYPE_MAP:
310
- raise ValueError(f"Invalid caption type: {prompt_key}")
311
-
312
- prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(
313
- length=length,
314
- word_count=length,
315
- style=art_style,
316
- style_characteristics=STYLE_CHARACTERISTICS.get(art_style, "its unique elements"),
317
- style_focus=STYLE_FOCUS.get(art_style, "its distinctive features")
318
- )
319
- print(f"Prompt: {prompt_str}")
320
-
321
- # Preprocess image
322
- image = input_image.resize((384, 384), Image.LANCZOS)
323
- pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
324
- pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
325
- pixel_values = pixel_values.to('cuda')
326
-
327
- # Tokenize the prompt
328
- prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
329
-
330
- # Embed image
331
- with torch.amp.autocast_mode.autocast('cuda', enabled=True):
332
- vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
333
- image_features = vision_outputs.hidden_states
334
- embedded_images = image_adapter(image_features)
335
- embedded_images = embedded_images.to('cuda')
336
-
337
- # Embed prompt
338
- prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
339
- assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
340
-
341
- # Check for bos_token_id and provide a fallback
342
- bos_token_id = tokenizer.bos_token_id
343
- if bos_token_id is None:
344
- print("Warning: bos_token_id is None. Using default value of 1.")
345
- bos_token_id = 1 # Common default, but may need adjustment
346
-
347
- embedded_bos = text_model.model.embed_tokens(torch.tensor([[bos_token_id]], device=text_model.device, dtype=torch.int64))
348
- eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
349
-
350
- # Construct prompts
351
- inputs_embeds = torch.cat([
352
- embedded_bos.expand(embedded_images.shape[0], -1, -1),
353
- embedded_images.to(dtype=embedded_bos.dtype),
354
- prompt_embeds.expand(embedded_images.shape[0], -1, -1),
355
- eot_embed.expand(embedded_images.shape[0], -1, -1),
356
- ], dim=1)
357
-
358
- input_ids = torch.cat([
359
- torch.tensor([[bos_token_id]], dtype=torch.long),
360
- torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
361
- prompt,
362
- torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
363
- ], dim=1).to('cuda')
364
- attention_mask = torch.ones_like(input_ids)
365
-
366
- generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None)
367
-
368
- # Trim off the prompt
369
- generate_ids = generate_ids[:, input_ids.shape[1]:]
370
- if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
371
- generate_ids = generate_ids[:, :-1]
372
-
373
- caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
374
-
375
- return caption.strip()
376
 
377
  css = """
378
  h1, h2, h3, h4, h5, h6, p, li, ul, ol, a, .centered-image {
@@ -392,71 +358,18 @@ ul, ol {
392
  }
393
  """
394
 
395
- ART_STYLES = [
396
- "Impressionism", "Cubism", "Surrealism", "Abstract Expressionism", "Pop Art",
397
- "Minimalism", "Baroque", "Renaissance", "Art Nouveau", "Gothic",
398
- "Romanticism", "Realism", "Expressionism", "Fauvism", "Art Deco",
399
- "Futurism", "Dadaism", "Pointillism", "Rococo", "Neoclassicism"
400
- ]
401
-
402
- STYLE_CHARACTERISTICS = {
403
- "Impressionism": "loose brushstrokes, emphasis on light and color, everyday subjects",
404
- "Cubism": "geometric shapes, multiple perspectives, fragmented forms",
405
- "Surrealism": "dreamlike imagery, unexpected juxtapositions, subconscious exploration",
406
- "Abstract Expressionism": "expressive brushwork, emotional content, abstract forms",
407
- "Pop Art": "bright colors, popular culture references, satire",
408
- "Minimalism": "simple forms, limited color palette, emphasis on space",
409
- "Baroque": "dramatic lighting, elaborate detail, grandeur",
410
- "Renaissance": "realistic depictions, perspective, religious themes",
411
- "Art Nouveau": "stylized forms, organic shapes, decorative elements",
412
- "Gothic": "dark themes, dramatic lighting, architectural elements",
413
- "Romanticism": "emotional content, nature scenes, idealized figures",
414
- "Realism": "detailed depictions, realistic textures, everyday subjects",
415
- "Expressionism": "emotional content, distorted forms, abstract elements",
416
- "Fauvism": "bold colors, abstract forms, emotional content",
417
- "Art Deco": "geometric shapes, streamlined forms, modern aesthetics",
418
- "Futurism": "dynamic forms, speed, technology",
419
- "Dadaism": "anti-art, absurdity, subversion of traditional art",
420
- "Pointillism": "small dots of color, impressionistic style, emphasis on light",
421
- "Rococo": "ornate style, lighthearted themes, decorative elements",
422
- "Neoclassicism": "classical style, balance, symmetry"
423
- }
424
-
425
- STYLE_FOCUS = {
426
- "Impressionism": "capturing fleeting moments and atmospheric effects",
427
- "Cubism": "deconstructing and reassembling forms from multiple viewpoints",
428
- "Surrealism": "creating a sense of the uncanny and exploring the subconscious mind",
429
- "Abstract Expressionism": "expressing emotional content through abstract forms",
430
- "Pop Art": "commenting on popular culture and satirizing consumerism",
431
- "Minimalism": "exploring the relationship between form and space",
432
- "Baroque": "creating dramatic and grandiose compositions",
433
- "Renaissance": "depicting realistic scenes and exploring perspective",
434
- "Art Nouveau": "incorporating organic and decorative elements",
435
- "Gothic": "exploring dark themes and dramatic lighting",
436
- "Romanticism": "depicting emotional scenes and idealized figures",
437
- "Realism": "capturing detailed and realistic textures",
438
- "Expressionism": "expressing emotional content through distorted forms",
439
- "Fauvism": "emphasizing bold colors and emotional content",
440
- "Art Deco": "incorporating geometric shapes and modern aesthetics",
441
- "Futurism": "depicting speed, technology, and dynamism",
442
- "Dadaism": "subverting traditional art and exploring absurdity",
443
- "Pointillism": "capturing light and color through small dots",
444
- "Rococo": "creating lighthearted and decorative compositions",
445
- "Neoclassicism": "achieving balance and symmetry in classical style"
446
- }
447
-
448
  with gr.Blocks(theme="Hev832/Applio", css=css) as demo:
449
  with gr.Tab("Welcome"):
450
  gr.Markdown(
451
  """
452
- <img src="https://path-to-yamamoto-logo.png" alt="Yamamoto Logo" class="centered-image">
453
 
454
  # 🎨 Yamamoto JoyCaption: AI-Powered Art Inspiration
455
 
456
  ## Accelerate Your Creative Workflow with Intelligent Image Analysis
457
 
458
  This innovative tool empowers Yamamoto's artists to quickly generate descriptive captions,<br>
459
- training prompts, or tags from existing artwork, fueling the creative process for GenAI models.
460
 
461
  ## 🚀 How It Works:
462
  1. **Upload Your Inspiration**: Drop in an image (e.g., a charcoal horse picture) that embodies your desired style.
@@ -468,109 +381,72 @@ with gr.Blocks(theme="Hev832/Applio", css=css) as demo:
468
 
469
  with gr.Tab("JoyCaption"):
470
  gr.Markdown("""
471
- # JoyCaption: AI-Powered Image Analysis Tool
472
 
473
- This tool helps you generate various types of text based on an uploaded image. Here's how to use it:
474
 
475
- 1. Upload an image
476
- 2. Choose your desired output type
477
- 3. Adjust settings as needed
478
- 4. Click 'Generate Caption' to get your result
479
- """)
 
 
480
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
  with gr.Row():
482
- with gr.Column(scale=1):
483
- input_image = gr.Image(type="pil", label="Upload Your Image")
484
 
485
  caption_type = gr.Dropdown(
486
- choices=[
487
- "descriptive",
488
- "training_prompt",
489
- "rng-tags",
490
- "thematic_analysis",
491
- "stylistic_comparison",
492
- "narrative_suggestion",
493
- "contextual_storytelling",
494
- "style_prompt"
495
- ],
496
- label="Output Type",
497
  value="descriptive",
498
  )
499
 
500
- gr.Markdown("""
501
- ### Output Types Explained:
502
- - **Descriptive**: A general description of the image
503
- - **Training Prompt**: A prompt for AI image generation
504
- - **RNG-Tags**: Tags for categorizing the image
505
- - **Thematic Analysis**: Exploration of themes in the image
506
- - **Stylistic Comparison**: Compares the image to art styles
507
- - **Narrative Suggestion**: A story idea based on the image
508
- - **Contextual Storytelling**: A background story for the image
509
- - **Style Prompt**: Analyzes the image in context of a specific art style
510
- """)
511
-
512
  caption_tone = gr.Dropdown(
513
  choices=["formal", "informal"],
514
- label="Tone",
515
  value="formal",
516
  )
517
 
518
- gr.Markdown("Choose between a formal (professional) or informal (casual) tone for the output.")
519
-
520
  caption_length = gr.Dropdown(
521
  choices=["any", "very short", "short", "medium-length", "long", "very long"] +
522
  [str(i) for i in range(20, 261, 10)],
523
- label="Length",
524
  value="any",
525
  )
526
 
527
- gr.Markdown("""
528
- Select the desired length of the output:
529
- - 'any': No specific length
530
- - Descriptive options: very short to very long
531
- - Numeric options: Specify exact word count (20 to 260 words)
532
- """)
533
-
534
- art_style = gr.Dropdown(
535
- choices=ART_STYLES,
536
- label="Art Style (for Style Prompt)",
537
- value="Impressionism",
538
- visible=False
539
- )
540
-
541
- gr.Markdown("Select an art style to analyze the image in that context. Only applicable for 'Style Prompt' output type.")
542
 
543
- with gr.Column(scale=1):
544
- output_caption = gr.Textbox(label="Generated Output", lines=10)
545
- generate_button = gr.Button("Generate Caption")
 
546
 
547
  gr.Markdown("""
548
- ### Additional Notes:
549
- - The 'Tone' setting doesn't affect 'RNG-Tags' and 'Training Prompt' outputs.
550
- - 'Art Style' is only used when 'Style Prompt' is selected as the output type.
551
- - The AI model analyzes the image and generates text based on your selections.
 
 
 
552
  """)
553
 
554
- run_button = gr.Button("Caption")
555
-
556
- with gr.Column():
557
- output_caption = gr.Textbox(label="Caption")
558
-
559
-
560
- caption_type.change(
561
- fn=lambda x: gr.update(visible=(x == "style_prompt")),
562
- inputs=[caption_type],
563
- outputs=[art_style]
564
- )
565
-
566
- generate_button.click(
567
- fn=stream_chat,
568
- inputs=[input_image, caption_type, caption_tone, caption_length, art_style],
569
- outputs=[output_caption]
570
- )
571
-
572
- run_button.click(fn=stream_chat, inputs=[input_image, caption_type, caption_tone, caption_length, art_style], outputs=[output_caption])
573
-
574
 
575
  if __name__ == "__main__":
576
  demo.launch()
 
261
  image_adapter.eval()
262
  image_adapter.to("cuda")
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
 
265
  @spaces.GPU()
266
  @torch.no_grad()
267
+ def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int) -> str:
268
+ torch.cuda.empty_cache()
269
+
270
+ # 'any' means no length specified
271
+ length = None if caption_length == "any" else caption_length
272
+
273
+ if isinstance(length, str):
274
+ try:
275
+ length = int(length)
276
+ except ValueError:
277
+ pass
278
+
279
+ # 'rng-tags' and 'training_prompt' don't have formal/informal tones
280
+ if caption_type == "rng-tags" or caption_type == "training_prompt":
281
+ caption_tone = "formal"
282
+
283
+ # Build prompt
284
+ prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
285
+ if prompt_key not in CAPTION_TYPE_MAP:
286
+ raise ValueError(f"Invalid caption type: {prompt_key}")
287
+
288
+ prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
289
+ print(f"Prompt: {prompt_str}")
290
+
291
+ # Preprocess image
292
+ #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
293
+ image = input_image.resize((384, 384), Image.LANCZOS)
294
+ pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
295
+ pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
296
+ pixel_values = pixel_values.to('cuda')
297
+
298
+ # Tokenize the prompt
299
+ prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
300
+
301
+ # Embed image
302
+ with torch.amp.autocast_mode.autocast('cuda', enabled=True):
303
+ vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
304
+ image_features = vision_outputs.hidden_states
305
+ embedded_images = image_adapter(image_features)
306
+ embedded_images = embedded_images.to('cuda')
307
+
308
+ # Embed prompt
309
+ prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
310
+ assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
311
+ embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
312
+ eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
313
+
314
+ # Construct prompts
315
+ inputs_embeds = torch.cat([
316
+ embedded_bos.expand(embedded_images.shape[0], -1, -1),
317
+ embedded_images.to(dtype=embedded_bos.dtype),
318
+ prompt_embeds.expand(embedded_images.shape[0], -1, -1),
319
+ eot_embed.expand(embedded_images.shape[0], -1, -1),
320
+ ], dim=1)
321
+
322
+ input_ids = torch.cat([
323
+ torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
324
+ torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
325
+ prompt,
326
+ torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
327
+ ], dim=1).to('cuda')
328
+ attention_mask = torch.ones_like(input_ids)
329
+
330
+ #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=False, suppress_tokens=None)
331
+ #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)
332
+ generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None) # Uses the default which is temp=0.6, top_p=0.9
333
+
334
+ # Trim off the prompt
335
+ generate_ids = generate_ids[:, input_ids.shape[1]:]
336
+ if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
337
+ generate_ids = generate_ids[:, :-1]
338
+
339
+ caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
340
+
341
+ return caption.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
342
 
343
  css = """
344
  h1, h2, h3, h4, h5, h6, p, li, ul, ol, a, .centered-image {
 
358
  }
359
  """
360
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
  with gr.Blocks(theme="Hev832/Applio", css=css) as demo:
362
  with gr.Tab("Welcome"):
363
  gr.Markdown(
364
  """
365
+ <img src="https://cdn-uploads.huggingface.co/production/uploads/64740cf7485a7c8e1bd51ac9/eO4MsESKd3K99rYiUuled.png" alt="Yamamoto Logo" class="centered-image">
366
 
367
  # 🎨 Yamamoto JoyCaption: AI-Powered Art Inspiration
368
 
369
  ## Accelerate Your Creative Workflow with Intelligent Image Analysis
370
 
371
  This innovative tool empowers Yamamoto's artists to quickly generate descriptive captions,<br>
372
+ training prompts, and tags from existing artwork, fueling the creative process for GenAI models.
373
 
374
  ## 🚀 How It Works:
375
  1. **Upload Your Inspiration**: Drop in an image (e.g., a charcoal horse picture) that embodies your desired style.
 
381
 
382
  with gr.Tab("JoyCaption"):
383
  gr.Markdown("""
384
+ # How to Use JoyCaption
385
 
386
+ Hello, artist! Let's make some fun captions for your pictures. Here's how:
387
 
388
+ 1. **Pick a Picture**: Find a cool picture you want to talk about and upload it.
389
+
390
+ 2. **Choose What You Want**:
391
+ - **Caption Type**:
392
+ * "Descriptive" tells you what's in the picture
393
+ * "Training Prompt" helps computers make similar pictures
394
+ * "RNG-Tags" gives you short words about the picture
395
 
396
+ 3. **Pick a Style** (for "Descriptive" only):
397
+ - "Formal" sounds like a teacher talking
398
+ - "Informal" sounds like a friend chatting
399
+
400
+ 4. **Decide How Long**:
401
+ - "Any" lets the computer decide
402
+ - Or pick a size from "very short" to "very long"
403
+ - You can even choose a specific number of words!
404
+
405
+ 5. **Make the Caption**: Click the "Caption" button and watch the magic happen!
406
+
407
+ Remember, have fun and be creative with your captions!
408
+ """)
409
+
410
  with gr.Row():
411
+ with gr.Column():
412
+ input_image = gr.Image(type="pil", label="Upload Your Picture Here")
413
 
414
  caption_type = gr.Dropdown(
415
+ choices=["descriptive", "training_prompt", "rng-tags"],
416
+ label="What Kind of Caption Do You Want?",
 
 
 
 
 
 
 
 
 
417
  value="descriptive",
418
  )
419
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  caption_tone = gr.Dropdown(
421
  choices=["formal", "informal"],
422
+ label="How Should It Sound? (For 'Descriptive' Only)",
423
  value="formal",
424
  )
425
 
 
 
426
  caption_length = gr.Dropdown(
427
  choices=["any", "very short", "short", "medium-length", "long", "very long"] +
428
  [str(i) for i in range(20, 261, 10)],
429
+ label="How Long Should It Be?",
430
  value="any",
431
  )
432
 
433
+ gr.Markdown("**Friendly Reminder:** The tone (formal/informal) only works for 'Descriptive' captions.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
 
435
+ run_button = gr.Button("Make My Caption!")
436
+
437
+ with gr.Column():
438
+ output_caption = gr.Textbox(label="Your Amazing Caption Appears Here")
439
 
440
  gr.Markdown("""
441
+ ## Tips for Great Captions:
442
+ - Try different types to see what you like best
443
+ - Experiment with formal and informal tones for fun variations
444
+ - Adjust the length to get just the right amount of detail
445
+ - If you don't like a caption, just click "Make My Caption!" again for a new one
446
+
447
+ Have a great time captioning your art!
448
  """)
449
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
 
451
# Script entry point: launch the Gradio app only when run directly,
# not when this module is imported.
if __name__ == "__main__":
    demo.launch()