import spaces
import gradio as gr
from torch import nn
from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
from pathlib import Path
import torch
import torch.amp.autocast_mode
from PIL import Image
import os
import torchvision.transforms.functional as TVF
CLIP_PATH = "google/siglip-so400m-patch14-384"
MODEL_PATH = "Qwen/Qwen2.5-7B-Instruct"
CHECKPOINT_PATH = Path("9em124t2-499968")
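# Prompt templates keyed by (caption_type, tone, length_is_descriptor, length_is_word_count).
# The two booleans are mutually exclusive: the third is True when the user picked a named
# length ("short", "very long", ...) and the fourth when they gave a numeric word count.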
CAPTION_TYPE_MAP = {
("descriptive", "formal", False, False): [
"Write a detailed, formal description of this image, focusing on composition, style, and artistic elements.",
"Provide a comprehensive, academic analysis of this artwork's visual characteristics and techniques."
],
("descriptive", "formal", False, True): [
"Craft a formal, concise description of this image within {word_count} words, highlighting key visual elements.",
"Summarize the artwork's main features and style in a formal tone, using no more than {word_count} words."
],
("descriptive", "formal", True, False): [
"Compose a {length} formal critique of this image, discussing its artistic merits and visual impact.",
"Create a {length} scholarly description of this artwork, analyzing its composition and aesthetic qualities."
],
("descriptive", "informal", False, False): [
"Describe this image as if you're explaining it to a friend, focusing on what stands out to you.",
"Give a casual, conversational rundown of what you see in this artwork and how it makes you feel."
],
("descriptive", "informal", False, True): [
"In about {word_count} words, give a laid-back description of this image's vibe and key features.",
"Summarize the coolest parts of this artwork in a casual tone, using roughly {word_count} words."
],
("descriptive", "informal", True, False): [
"Write a {length} chill description of this image, highlighting what you find most interesting or unique.",
"Give a {length} relaxed explanation of what's going on in this artwork and why it catches your eye."
],
("training_prompt", "formal", False, False): [
"Generate a detailed stable diffusion prompt to recreate this image, including style, composition, and key elements.",
"Craft a comprehensive prompt for an AI art generator to produce an image in the same style and mood as this artwork."
],
("training_prompt", "formal", False, True): [
"Within {word_count} words, create a precise stable diffusion prompt capturing the essence of this image.",
"Compose a concise AI art prompt of {word_count} words to replicate this artwork's style and content."
],
("training_prompt", "formal", True, False): [
"Write a {length} stable diffusion prompt that thoroughly describes this image's style, subject, and artistic techniques.",
"Develop a {length} detailed prompt for AI art generation, breaking down the key visual elements and artistic approach of this image."
],
("rng-tags", "formal", False, False): [
"Generate a comprehensive list of Booru tags describing this image's content, style, and artistic elements.",
"Create an extensive set of Booru tags covering all aspects of this artwork, including subject, technique, and mood."
],
("rng-tags", "formal", False, True): [
"Produce a focused list of Booru tags within {word_count} words, capturing the most important aspects of this image.",
"Compile a concise set of Booru tags, limited to {word_count} words, that best represent this artwork's key features."
],
("rng-tags", "formal", True, False): [
"Generate a {length} list of Booru tags, providing a thorough categorization of this image's content and style.",
"Create a {length} set of Booru tags that extensively describe all visual elements and artistic choices in this artwork."
],
("artistic_inspiration", "formal", False, False): [
"Analyze this image and suggest artistic variations or extensions that could be created based on its style and theme.",
"Provide a formal interpretation of this artwork's mood and style, offering ideas for complementary pieces or a series."
],
("artistic_inspiration", "informal", False, False): [
"Brainstorm some cool ideas for new artworks inspired by this image's style or subject matter.",
"Riff on this artwork's vibe and come up with some creative spin-offs or related pieces an artist could make."
],
("technical_breakdown", "formal", False, False): [
"Provide a detailed technical analysis of the artistic techniques and materials likely used to create this image.",
"Break down the compositional elements and artistic methods employed in this artwork, suitable for an art student's study."
],
("emotional_response", "informal", False, False): [
"Describe the emotions and feelings this artwork evokes, and explain why it might resonate with viewers.",
"Share your gut reaction to this image and speculate on what the artist might have been feeling or thinking."
],
("thematic_analysis", "formal", False, False): [
"Provide an in-depth analysis of the themes presented in this image, exploring the underlying messages and concepts.",
"Analyze the primary and secondary themes of this artwork, discussing their significance and interplay."
],
("thematic_analysis", "formal", False, True): [
"Within {word_count} words, dissect the main themes of this image, highlighting their relevance and impact.",
"Craft a concise thematic analysis of this artwork in {word_count} words, focusing on its core messages."
],
("thematic_analysis", "formal", True, False): [
"Write a {length} formal exploration of the themes depicted in this image, examining their depth and meaning.",
"Develop a {length} scholarly analysis of the thematic elements in this artwork, discussing their significance."
],
("stylistic_comparison", "informal", False, False): [
"Compare the style of this image to other famous art movements or artists, highlighting similarities and differences.",
"Describe how this artwork's style relates to [specific artist/style], and what makes it unique."
],
("stylistic_comparison", "informal", False, True): [
"In about {word_count} words, compare this image's style with other known art styles or artists.",
"Summarize the stylistic similarities and differences of this artwork compared to other genres in {word_count} words."
],
("stylistic_comparison", "informal", True, False): [
"Write a {length} casual comparison of this image's style with other art movements or famous artists.",
"Give a {length} relaxed description of how this artwork's style aligns or differs from other genres."
],
("narrative_suggestion", "formal", False, False): [
"Create a short narrative inspired by this image, outlining a possible story that reflects its visual elements.",
"Develop a brief storyline that complements the themes and mood depicted in this artwork."
],
("narrative_suggestion", "formal", False, True): [
"Within {word_count} words, outline a narrative inspired by this image's visual elements and mood.",
"Compose a concise story idea based on the themes and composition of this artwork in {word_count} words."
],
("narrative_suggestion", "formal", True, False): [
"Write a {length} formal narrative inspired by this image, detailing a story that aligns with its visual and thematic elements.",
"Develop a {length} scholarly storyline that reflects the mood and composition of this artwork."
],
("contextual_storytelling", "informal", False, False): [
"Tell a cool story that could be happening in the scene of this image, based on its visual cues.",
"Imagine a background story for this artwork, explaining what's happening and why."
],
("contextual_storytelling", "informal", False, True): [
"In about {word_count} words, create a backstory for the scene depicted in this image.",
"Summarize a possible background narrative for this artwork in {word_count} words."
],
("contextual_storytelling", "informal", True, False): [
"Write a {length} informal story that provides context to the scene portrayed in this image.",
"Give a {length} casual backstory explaining the events depicted in this artwork."
],
("style_prompt", "formal", False, False): [
"Generate a detailed stable diffusion prompt to recreate this image, including style, composition, and key elements.",
"Craft a comprehensive prompt for an AI art generator to produce an image in the same style and mood as this artwork."
],
("style_prompt", "formal", False, True): [
"Within {word_count} words, create a precise stable diffusion prompt capturing the essence of this image.",
"Compose a concise AI art prompt of {word_count} words to replicate this artwork's style and content."
],
("style_prompt", "formal", True, False): [
"Write a {length} stable diffusion prompt that thoroughly describes this image's style, subject, and artistic techniques.",
"Develop a {length} detailed prompt for AI art generation, breaking down the key visual elements and artistic approach of this image."
],
("style_prompt", "informal", False, False): [
"Imagine this image is in an exhibition of {style} art. Describe what makes it fit in or stand out from other {style} pieces.",
"Give a casual rundown of how this artwork vibes with the {style} movement. What's similar? What's different? What's cool about it?"
],
("style_prompt", "informal", False, True): [
"In about {word_count} words, chat about how this image relates to {style} art. What catches your eye as typical or unusual for the style?",
"Summarize in roughly {word_count} words how this artwork plays with {style} ideas. What's familiar? What's a twist on the style?"
],
("style_prompt", "informal", True, False): [
"Write a {length} chill analysis of this image as if it's part of a {style} art show. What works? What's surprising? How does it make you feel?",
"Give a {length} relaxed breakdown of how this artwork fits (or doesn't) into the {style} scene. What's your take on its use of {style} elements?"
],
}
HF_TOKEN = os.environ.get("HF_TOKEN", None)
class ImageAdapter(nn.Module):
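    # Projects the vision tower's hidden states into the text model's embedding space and
    # brackets the resulting image tokens with learned <|image_start|>/<|image_end|> embeddings.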
def __init__(self, input_features: int, output_features: int, ln1: bool, pos_emb: bool, num_image_tokens: int, deep_extract: bool):
super().__init__()
self.deep_extract = deep_extract
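        # deep_extract concatenates features from five hidden layers, so the effective
        # input width grows 5x (see forward()).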
if self.deep_extract:
input_features = input_features * 5
self.linear1 = nn.Linear(input_features, output_features)
self.activation = nn.GELU()
self.linear2 = nn.Linear(output_features, output_features)
self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
self.pos_emb = None if not pos_emb else nn.Parameter(torch.zeros(num_image_tokens, input_features))
# Mode token
#self.mode_token = nn.Embedding(n_modes, output_features)
#self.mode_token.weight.data.normal_(mean=0.0, std=0.02) # Matches HF's implementation of llama3
# Other tokens (<|image_start|>, <|image_end|>, <|eot_id|>)
self.other_tokens = nn.Embedding(3, output_features)
self.other_tokens.weight.data.normal_(mean=0.0, std=0.02) # Matches HF's implementation of llama3
def forward(self, vision_outputs: torch.Tensor):
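        # vision_outputs is the tuple of per-layer hidden states; by default only the
        # second-to-last layer is used, as is common for CLIP-style adapters.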
if self.deep_extract:
x = torch.concat((
vision_outputs[-2],
vision_outputs[3],
vision_outputs[7],
vision_outputs[13],
vision_outputs[20],
), dim=-1)
assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}" # batch, tokens, features
assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5, f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
else:
x = vision_outputs[-2]
x = self.ln1(x)
if self.pos_emb is not None:
assert x.shape[-2:] == self.pos_emb.shape, f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}"
x = x + self.pos_emb
x = self.linear1(x)
x = self.activation(x)
x = self.linear2(x)
# Mode token
#mode_token = self.mode_token(mode)
#assert mode_token.shape == (x.shape[0], mode_token.shape[1], x.shape[2]), f"Expected {(x.shape[0], 1, x.shape[2])}, got {mode_token.shape}"
#x = torch.cat((x, mode_token), dim=1)
# <|image_start|>, IMAGE, <|image_end|>
other_tokens = self.other_tokens(torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(x.shape[0], -1))
assert other_tokens.shape == (x.shape[0], 2, x.shape[2]), f"Expected {(x.shape[0], 2, x.shape[2])}, got {other_tokens.shape}"
x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)
return x
def get_eot_embedding(self):
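        # Returns the learned <|eot_id|> embedding (index 2 in other_tokens).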
return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)
# Load CLIP
print("Loading CLIP")
clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
clip_model = AutoModel.from_pretrained(CLIP_PATH)
clip_model = clip_model.vision_model
if (CHECKPOINT_PATH / "clip_model.pt").exists():
print("Loading VLM's custom vision model")
checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu')
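    # Strip prefixes left behind by torch.compile ("_orig_mod.") and DDP ("module.") wrapping.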
checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
clip_model.load_state_dict(checkpoint)
del checkpoint
clip_model.eval()
clip_model.requires_grad_(False)
clip_model.to("cuda")
# Tokenizer
print("Loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
assert isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)), f"Tokenizer is of type {type(tokenizer)}"
# LLM
print("Loading LLM")
if (CHECKPOINT_PATH / "text_model").exists():
print("Loading VLM's custom text model")
text_model = AutoModelForCausalLM.from_pretrained(CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16)
else:
text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16)
text_model.eval()
# Image Adapter
print("Loading image adapter")
image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False)
image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu"))
image_adapter.eval()
image_adapter.to("cuda")
@spaces.GPU()
@torch.no_grad()
def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int,
lens_type: str = "standard", film_stock: str = "digital",
composition: str = "rule of thirds", lighting: str = "natural") -> str:
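    # Builds a prompt from CAPTION_TYPE_MAP, embeds the image via CLIP + ImageAdapter,
    # and generates a caption. The lens/film/composition/lighting arguments only affect
    # the "style_prompt" caption type.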
torch.cuda.empty_cache()
# 'any' means no length specified
length = None if caption_length == "any" else caption_length
if isinstance(length, str):
try:
length = int(length)
except ValueError:
pass
# 'rng-tags' and 'training_prompt' don't have formal/informal tones
if caption_type in ("rng-tags", "training_prompt"):
caption_tone = "formal"
# Build prompt
prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
if prompt_key not in CAPTION_TYPE_MAP:
        raise ValueError(f"Invalid caption type/tone/length combination: {prompt_key}")
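    # Each template contains at most one of {length}/{word_count}, so both are filled with the same value.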
    # Informal "style_prompt" templates also reference {style}, which has no UI input yet; a
    # generic default is supplied so .format() doesn't raise KeyError.
    prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length, style="referenced")
# Add style prompt details if applicable
if caption_type == "style_prompt":
prompt_str += (f" The prompt should specifically include details about using a {lens_type} lens, "
f"{film_stock} film stock, {composition} composition, and {lighting} lighting. "
f"Format the output as a comma-separated list of descriptors and modifiers, "
f"suitable for direct input into a Stable Diffusion interface.")
print(f"Prompt: {prompt_str}")
# Preprocess image
image = input_image.resize((384, 384), Image.LANCZOS)
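    # Scale to [0, 1], then normalize to [-1, 1] (mean 0.5, std 0.5) as SigLIP expects.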
pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
pixel_values = pixel_values.to('cuda')
# Tokenize the prompt
prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
# Embed image
with torch.amp.autocast_mode.autocast('cuda', enabled=True):
vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
image_features = vision_outputs.hidden_states
embedded_images = image_adapter(image_features)
embedded_images = embedded_images.to('cuda')
# Embed prompt
prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
# Check if bos_token_id exists
if tokenizer.bos_token_id is None:
print("Warning: bos_token_id is None. Using default value of 1.")
bos_token_id = 1
else:
bos_token_id = tokenizer.bos_token_id
embedded_bos = text_model.model.embed_tokens(torch.tensor([[bos_token_id]], device=text_model.device, dtype=torch.int64))
eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
# Construct prompts
inputs_embeds = torch.cat([
embedded_bos.expand(embedded_images.shape[0], -1, -1),
embedded_images.to(dtype=embedded_bos.dtype),
prompt_embeds.expand(embedded_images.shape[0], -1, -1),
eot_embed.expand(embedded_images.shape[0], -1, -1),
], dim=1)
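    # Build matching input_ids: zeros act as placeholders for the image-embedding positions so
    # lengths line up with inputs_embeds. Note "<|eot_id|>" is a Llama-3-style token; this assumes
    # the checkpoint's tokenizer defines it (stock Qwen tokenizers use <|im_end|> instead).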
input_ids = torch.cat([
torch.tensor([[bos_token_id]], dtype=torch.long),
torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
prompt,
torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
], dim=1).to('cuda')
attention_mask = torch.ones_like(input_ids)
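    # Passing input_ids alongside inputs_embeds makes generate() return prompt + completion ids,
    # which is why the prompt is trimmed off below.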
generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None)
# Trim off the prompt
generate_ids = generate_ids[:, input_ids.shape[1]:]
    if generate_ids[0][-1] in (tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")):
generate_ids = generate_ids[:, :-1]
caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
# For style_prompt, format the output for easy copying into image generation platforms
if caption_type == "style_prompt":
caption = "Stable Diffusion Prompt: " + caption.replace("\n", ", ")
return caption.strip()
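# Example invocation (hypothetical, bypassing the Gradio UI; "horse.png" is a placeholder path):
#   caption = stream_chat(Image.open("horse.png"), "style_prompt", "informal", "medium-length")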
css = """
h1, h2, h3, h4, h5, h6, p, li, ul, ol, a, .centered-image {
text-align: center;
display: block;
margin-left: auto;
margin-right: auto;
}
ul, ol {
margin-left: auto;
margin-right: auto;
display: table;
}
.centered-image {
max-width: 100%;
height: auto;
}
"""
with gr.Blocks(theme="Hev832/Applio", css=css) as demo:
with gr.Tab("Welcome"):
gr.Markdown(
"""
# 🎨 Yamamoto JoyCaption: AI-Powered Art Inspiration
## Accelerate Your Creative Workflow with Intelligent Image Analysis
This innovative tool empowers Yamamoto's artists to quickly generate descriptive captions,
training prompts, and tags from existing artwork, fueling the creative process for GenAI models.
## 🚀 How It Works:
1. **Upload Your Inspiration**: Drop in an image (e.g., a charcoal horse picture) that embodies your desired style.
2. **Choose Your Output**: Select from descriptive captions, training prompts, or tags.
3. **Customize the Results**: Adjust tone, length, and other parameters to fine-tune the output.
4. **Generate and Iterate**: Click 'Caption' to analyze your image and use the results to inspire new creations.
"""
)
with gr.Tab("JoyCaption"):
with gr.Accordion("How to Use JoyCaption", open=False):
gr.Markdown("""
# How to Use JoyCaption
Hello, artist! Let's make some fun captions for your pictures. Here's how:
1. **Pick a Picture**: Find a cool picture you want to talk about and upload it.
2. **Choose What You Want**:
- **Caption Type**:
* "Descriptive" tells you what's in the picture
* "Training Prompt" helps computers make similar pictures
* "RNG-Tags" gives you short words about the picture
* "Style Prompt" creates detailed prompts for image generation
3. **Pick a Style** (for "Descriptive" and "Style Prompt" only):
- "Formal" sounds like a teacher talking
- "Informal" sounds like a friend chatting
4. **Decide How Long**:
- "Any" lets the computer decide
- Or pick a size from "very short" to "very long"
- You can even choose a specific number of words!
5. **Advanced Options** (for "Style Prompt" only):
- Choose lens type, film stock, composition, and lighting details
6. **Make the Caption**: Click the "Make My Caption!" button and watch the magic happen!
Remember, have fun and be creative with your captions!
## Tips for Great Captions:
- Try different types to see what you like best
- Experiment with formal and informal tones for fun variations
- Adjust the length to get just the right amount of detail
- For "Style Prompt", play with the advanced options for more specific results
- If you don't like a caption, just click "Make My Caption!" again for a new one
Have a great time captioning your art!
""")
with gr.Row():
with gr.Column():
input_image = gr.Image(type="pil", label="Upload Your Picture Here")
caption_type = gr.Dropdown(
choices=["descriptive", "training_prompt", "rng-tags", "style_prompt"],
label="What Kind of Caption Do You Want?",
value="descriptive",
)
caption_tone = gr.Dropdown(
choices=["formal", "informal"],
label="How Should It Sound? (For 'Descriptive' and 'Style Prompt' Only)",
value="formal",
)
caption_length = gr.Dropdown(
choices=["any", "very short", "short", "medium-length", "long", "very long"] +
[str(i) for i in range(20, 261, 10)],
label="How Long Should It Be?",
value="any",
)
with gr.Accordion("Advanced Options (for Style Prompt)", open=False):
lens_type = gr.Dropdown(
choices=["wide-angle", "telephoto", "macro", "fisheye", "standard"],
label="Lens Type",
value="standard",
)
film_stock = gr.Dropdown(
choices=["35mm", "medium format", "large format", "digital"],
label="Film Stock",
value="digital",
)
composition = gr.Dropdown(
choices=["rule of thirds", "golden ratio", "symmetrical", "asymmetrical", "centered"],
label="Composition",
value="rule of thirds",
)
lighting = gr.Dropdown(
choices=["natural", "studio", "high-key", "low-key", "dramatic"],
label="Lighting",
value="natural",
)
gr.Markdown("**Friendly Reminder:** The tone and advanced options only work for specific caption types.")
run_button = gr.Button("Make My Caption!")
with gr.Column():
                output_caption = gr.Textbox(label="Your Generated Caption ('Style Prompt' output can be pasted straight into Stable Diffusion)", lines=10)
gr.Markdown("""
## How to Use Your Generated Prompt:
1. For "Style Prompt" captions, the output is formatted for direct use in Stable Diffusion.
2. Simply copy the entire text from the output box.
3. Paste it into your preferred Stable Diffusion interface or any other AI image generation platform.
4. Adjust or add to the prompt as desired to fine-tune your image generation.
Remember, you can always regenerate or modify the prompt to get different results!
""")
run_button.click(fn=stream_chat, inputs=[input_image, caption_type, caption_tone, caption_length, lens_type, film_stock, composition, lighting], outputs=[output_caption])
if __name__ == "__main__":
demo.launch()