Severian committed on
Commit d2c00ac · verified · 1 Parent(s): 7c751fb

Update app.py

Files changed (1)
  1. app.py +116 -105
app.py CHANGED
@@ -12,24 +12,28 @@ import torchvision.transforms.functional as TVF
 
 
  CLIP_PATH = "google/siglip-so400m-patch14-384"
- MODEL_PATH = "Qwen/Qwen2.5-7B-Instruct"
  CHECKPOINT_PATH = Path("9em124t2-499968")
  CAPTION_TYPE_MAP = {
- ("descriptive", "formal", False, False): ["Write a descriptive caption for this image in a formal tone."],
- ("descriptive", "formal", False, True): ["Write a descriptive caption for this image in a formal tone within {word_count} words."],
- ("descriptive", "formal", True, False): ["Write a {length} descriptive caption for this image in a formal tone."],
- ("descriptive", "informal", False, False): ["Write a descriptive caption for this image in a casual tone."],
- ("descriptive", "informal", False, True): ["Write a descriptive caption for this image in a casual tone within {word_count} words."],
- ("descriptive", "informal", True, False): ["Write a {length} descriptive caption for this image in a casual tone."],
- ("training_prompt", "formal", False, False): ["Write a stable diffusion prompt for this image."],
- ("training_prompt", "formal", False, True): ["Write a stable diffusion prompt for this image within {word_count} words."],
- ("training_prompt", "formal", True, False): ["Write a {length} stable diffusion prompt for this image."],
- ("rng-tags", "formal", False, False): ["Write a list of Booru tags for this image."],
- ("rng-tags", "formal", False, True): ["Write a list of Booru tags for this image within {word_count} words."],
- ("rng-tags", "formal", True, False): ["Write a {length} list of Booru tags for this image."],
- ("style_prompt", "formal", False, False): ["Generate a detailed stable diffusion prompt to recreate this image, including style, composition, and key elements."],
- ("style_prompt", "formal", False, True): ["Within {word_count} words, create a precise stable diffusion prompt capturing the essence of this image."],
- ("style_prompt", "formal", True, False): ["Write a {length} stable diffusion prompt that thoroughly describes this image's style, subject, and artistic techniques."],
  }
 
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
@@ -139,91 +143,87 @@ image_adapter.eval()
  image_adapter.to("cuda")
 
 
- @spaces.GPU()
- @torch.no_grad()
- def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int,
- lens_type: str = "standard", film_stock: str = "digital",
- composition: str = "rule of thirds", lighting: str = "natural") -> str:
- torch.cuda.empty_cache()
-
- # 'any' means no length specified
- length = None if caption_length == "any" else caption_length
-
- if isinstance(length, str):
- try:
- length = int(length)
- except ValueError:
- pass
-
- # 'rng-tags', 'training_prompt', and 'style_prompt' don't have formal/informal tones
- if caption_type in ["rng-tags", "training_prompt", "style_prompt"]:
- caption_tone = "formal"
-
- # Build prompt
- prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
- if prompt_key not in CAPTION_TYPE_MAP:
- raise ValueError(f"Invalid caption type: {prompt_key}")
-
- prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
-
- if caption_type == "style_prompt":
- prompt_str += (f" Include details about using a {lens_type} lens, "
- f"{film_stock} film stock, {composition} composition, and {lighting} lighting.")
-
- print(f"Prompt: {prompt_str}")
-
- # Preprocess image
  image = input_image.resize((384, 384), Image.LANCZOS)
  pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
  pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
- pixel_values = pixel_values.to('cuda')
 
- # Tokenize the prompt
  prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
-
- # Embed image
- with torch.amp.autocast_mode.autocast('cuda', enabled=True):
- vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
- image_features = vision_outputs.hidden_states
- embedded_images = image_adapter(image_features)
- embedded_images = embedded_images.to('cuda')
-
- # Embed prompt
  prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
- assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
  embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
  eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
 
- # Construct prompts
  inputs_embeds = torch.cat([
- embedded_bos.expand(embedded_images.shape[0], -1, -1),
- embedded_images.to(dtype=embedded_bos.dtype),
- prompt_embeds.expand(embedded_images.shape[0], -1, -1),
- eot_embed.expand(embedded_images.shape[0], -1, -1),
  ], dim=1)
 
  input_ids = torch.cat([
  torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
- torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
  prompt,
  torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
  ], dim=1).to('cuda')
  attention_mask = torch.ones_like(input_ids)
 
- generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None)
 
- # Trim off the prompt
  generate_ids = generate_ids[:, input_ids.shape[1]:]
  if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
  generate_ids = generate_ids[:, :-1]
 
- caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
 
- # For style_prompt, format the output for easy copying into image generation platforms
  if caption_type == "style_prompt":
- caption = "Stable Diffusion Prompt: " + caption.replace("\n", ", ")
 
- return caption.strip()
 
  css = """
  h1, h2, h3, h4, h5, h6, p, li, ul, ol, a, .centered-image {
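For reference, the image preprocessing shared by the old and new code converts the PIL image to a tensor and rescales it into the [-1, 1] range used by the SigLIP image encoder loaded from CLIP_PATH. A standalone, CPU-only sketch of just that step (the all-white test image and the print are illustrative, not part of app.py):

```python
from PIL import Image
import torchvision.transforms.functional as TVF

# Illustration only: pil_to_tensor yields uint8 in [0, 255]; dividing by 255 and
# normalizing with mean=0.5, std=0.5 maps every pixel into [-1, 1].
image = Image.new("RGB", (384, 384), color=(255, 255, 255))   # dummy all-white image
pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0  # float in [0, 1]
pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])      # now in [-1, 1]
print(pixel_values.min().item(), pixel_values.max().item())   # 1.0 1.0 (an all-black image gives -1.0)
```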
@@ -243,10 +243,11 @@ ul, ol {
  }
  """
 
  with gr.Blocks(theme="Hev832/Applio", css=css) as demo:
  with gr.Tab("Welcome"):
  gr.Markdown(
- """
  <img src="https://path-to-yamamoto-logo.png" alt="Yamamoto Logo" class="centered-image">
 
  # 🎨 Yamamoto JoyCaption: AI-Powered Art Inspiration
@@ -263,7 +264,7 @@ with gr.Blocks(theme="Hev832/Applio", css=css) as demo:
  4. **Generate and Iterate**: Click 'Caption' to analyze your image and use the results to inspire new creations.
  """
  )
-
  with gr.Tab("JoyCaption"):
  with gr.Accordion("How to Use JoyCaption", open=False):
  gr.Markdown("""
@@ -308,58 +309,68 @@ with gr.Blocks(theme="Hev832/Applio", css=css) as demo:
 
  with gr.Row():
  with gr.Column():
- input_image = gr.Image(type="pil", label="Upload Your Picture Here")
 
  caption_type = gr.Dropdown(
  choices=["descriptive", "training_prompt", "rng-tags", "style_prompt"],
- label="What Kind of Caption Do You Want?",
  value="descriptive",
  )
 
  caption_tone = gr.Dropdown(
  choices=["formal", "informal"],
- label="How Should It Sound? (For 'Descriptive' and 'Style Prompt' Only)",
  value="formal",
  )
 
  caption_length = gr.Dropdown(
  choices=["any", "very short", "short", "medium-length", "long", "very long"] +
  [str(i) for i in range(20, 261, 10)],
- label="How Long Should It Be?",
  value="any",
  )
 
- with gr.Accordion("Advanced Options (for Style Prompt)", open=False):
- lens_type = gr.Dropdown(
- choices=["wide-angle", "telephoto", "macro", "fisheye", "standard"],
- label="Lens Type",
- value="standard",
- )
- film_stock = gr.Dropdown(
- choices=["35mm", "medium format", "large format", "digital"],
- label="Film Stock",
- value="digital",
- )
- composition = gr.Dropdown(
- choices=["rule of thirds", "golden ratio", "symmetrical", "asymmetrical", "centered"],
- label="Composition",
- value="rule of thirds",
- )
- lighting = gr.Dropdown(
- choices=["natural", "studio", "high-key", "low-key", "dramatic"],
- label="Lighting",
- value="natural",
- )
-
- gr.Markdown("**Friendly Reminder:** The tone and advanced options only work for specific caption types.")
 
  run_button = gr.Button("Make My Caption!")
 
  with gr.Column():
- output_caption = gr.Textbox(label="Your Amazing Caption Appears Here", lines=10)
 
- run_button.click(fn=stream_chat, inputs=[input_image, caption_type, caption_tone, caption_length, lens_type, film_stock, composition, lighting], outputs=[output_caption])
 
 
  if __name__ == "__main__":
  demo.launch()
 
  CLIP_PATH = "google/siglip-so400m-patch14-384"
+ MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
  CHECKPOINT_PATH = Path("9em124t2-499968")
+ TITLE = "<h1><center>JoyCaption Alpha One (2024-09-20a)</center></h1>"
  CAPTION_TYPE_MAP = {
+ ("descriptive", "formal", False, False): ["Write a descriptive caption for this image in a formal tone."],
+ ("descriptive", "formal", False, True): ["Write a descriptive caption for this image in a formal tone within {word_count} words."],
+ ("descriptive", "formal", True, False): ["Write a {length} descriptive caption for this image in a formal tone."],
+ ("descriptive", "informal", False, False): ["Write a descriptive caption for this image in a casual tone."],
+ ("descriptive", "informal", False, True): ["Write a descriptive caption for this image in a casual tone within {word_count} words."],
+ ("descriptive", "informal", True, False): ["Write a {length} descriptive caption for this image in a casual tone."],
+
+ ("training_prompt", "formal", False, False): ["Write a stable diffusion prompt for this image."],
+ ("training_prompt", "formal", False, True): ["Write a stable diffusion prompt for this image within {word_count} words."],
+ ("training_prompt", "formal", True, False): ["Write a {length} stable diffusion prompt for this image."],
+
+ ("rng-tags", "formal", False, False): ["Write a list of Booru tags for this image."],
+ ("rng-tags", "formal", False, True): ["Write a list of Booru tags for this image within {word_count} words."],
+ ("rng-tags", "formal", True, False): ["Write a {length} list of Booru tags for this image."],
+
+ ("style_prompt", "formal", False, False): ["Generate a detailed style prompt for this image, including lens type, film stock, composition notes, and lighting aspects."],
+ ("style_prompt", "formal", False, True): ["Generate a detailed style prompt for this image within {word_count} words, including lens type, film stock, composition notes, and lighting aspects."],
+ ("style_prompt", "formal", True, False): ["Generate a {length} detailed style prompt for this image, including lens type, film stock, composition notes, and lighting aspects."],
  }
 
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
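The templates above are keyed on (caption type, tone, length-is-preset, length-is-number), which is how stream_chat picks and fills a template. A minimal, standalone sketch of that lookup, using only a subset of the map; it keeps the older try/except fallback so word presets such as "short" survive as strings, whereas the new stream_chat raises ValueError for any non-numeric length other than "any":

```python
# Sketch only; CAPTION_TYPE_MAP is truncated to three entries for illustration.
CAPTION_TYPE_MAP_SUBSET = {
    ("descriptive", "formal", False, False): ["Write a descriptive caption for this image in a formal tone."],
    ("descriptive", "formal", False, True): ["Write a descriptive caption for this image in a formal tone within {word_count} words."],
    ("descriptive", "formal", True, False): ["Write a {length} descriptive caption for this image in a formal tone."],
}

def build_prompt(caption_type: str, caption_tone: str, caption_length) -> str:
    length = None if caption_length == "any" else caption_length   # "any" means unconstrained
    if isinstance(length, str):
        try:
            length = int(length)          # numeric strings become word counts
        except ValueError:
            pass                          # keep presets such as "short" as strings
    key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
    return CAPTION_TYPE_MAP_SUBSET[key][0].format(length=length, word_count=length)

print(build_prompt("descriptive", "formal", "any"))    # plain template, no length clause
print(build_prompt("descriptive", "formal", "short"))  # fills the {length} preset variant
print(build_prompt("descriptive", "formal", "40"))     # fills the {word_count} variant
```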
 
  image_adapter.to("cuda")
 
 
+ def preprocess_image(input_image: Image.Image) -> torch.Tensor:
+ """
+ Preprocess the input image for the CLIP model.
+ """
  image = input_image.resize((384, 384), Image.LANCZOS)
  pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
  pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
+ return pixel_values.to('cuda')
 
+ def generate_caption(text_model, tokenizer, image_features, prompt_str: str, max_new_tokens: int = 300) -> str:
+ """
+ Generate a caption based on the image features and prompt.
+ """
  prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
  prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
  embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
  eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
 
  inputs_embeds = torch.cat([
+ embedded_bos.expand(image_features.shape[0], -1, -1),
+ image_features.to(dtype=embedded_bos.dtype),
+ prompt_embeds.expand(image_features.shape[0], -1, -1),
+ eot_embed.expand(image_features.shape[0], -1, -1),
  ], dim=1)
 
  input_ids = torch.cat([
  torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
+ torch.zeros((1, image_features.shape[1]), dtype=torch.long),
  prompt,
  torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
  ], dim=1).to('cuda')
  attention_mask = torch.ones_like(input_ids)
 
+ generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens, do_sample=True, suppress_tokens=None)
 
  generate_ids = generate_ids[:, input_ids.shape[1]:]
  if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
  generate_ids = generate_ids[:, :-1]
 
+ return tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0].strip()
+
+ @spaces.GPU()
+ @torch.no_grad()
+ def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int, lens_type: str = "", film_stock: str = "", composition_style: str = "") -> str:
+ """
+ Generate a caption or style prompt based on the input image and parameters.
+ """
+ torch.cuda.empty_cache()
 
+ try:
+ length = None if caption_length == "any" else caption_length
+ if isinstance(length, str):
+ length = int(length)
+ except ValueError:
+ raise ValueError(f"Invalid caption length: {caption_length}")
+
+ if caption_type in ["rng-tags", "training_prompt", "style_prompt"]:
+ caption_tone = "formal"
+
+ prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
+ if prompt_key not in CAPTION_TYPE_MAP:
+ raise ValueError(f"Invalid caption type: {prompt_key}")
+
+ prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
+
  if caption_type == "style_prompt":
+ prompt_str += f" Lens type: {lens_type}. Film stock: {film_stock}. Composition style: {composition_style}."
+
+ print(f"Prompt: {prompt_str}")
 
+ pixel_values = preprocess_image(input_image)
+
+ with torch.amp.autocast_mode.autocast('cuda', enabled=True):
+ vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
+ image_features = vision_outputs.hidden_states
+ embedded_images = image_adapter(image_features)
+ embedded_images = embedded_images.to('cuda')
+
+ caption = generate_caption(text_model, tokenizer, embedded_images, prompt_str)
+
+ return caption
 
  css = """
  h1, h2, h3, h4, h5, h6, p, li, ul, ol, a, .centered-image {
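The trickiest part of generate_caption is keeping inputs_embeds and input_ids the same length: the image embeddings sit between the BOS and the prompt embeddings, and the matching positions in input_ids are zero placeholders because the image tokens have no ids of their own. A shape-only sketch of that alignment (hidden_size, num_image_tokens, prompt_len, and the token ids are made-up values for illustration):

```python
import torch

hidden_size = 8          # illustrative; the real text model's hidden size is far larger
num_image_tokens = 5     # illustrative; set by the image adapter's output
prompt_len = 3

embedded_bos   = torch.randn(1, 1, hidden_size)
embedded_image = torch.randn(1, num_image_tokens, hidden_size)
prompt_embeds  = torch.randn(1, prompt_len, hidden_size)
eot_embed      = torch.randn(1, 1, hidden_size)

# [BOS][image tokens][prompt][<|eot_id|>], the same order generate_caption uses
inputs_embeds = torch.cat([embedded_bos, embedded_image, prompt_embeds, eot_embed], dim=1)

# input_ids mirrors that layout; zeros stand in for the image positions
input_ids = torch.cat([
    torch.tensor([[1]], dtype=torch.long),                  # stand-in for bos_token_id
    torch.zeros((1, num_image_tokens), dtype=torch.long),   # image placeholders
    torch.zeros((1, prompt_len), dtype=torch.long),         # stand-in for the prompt token ids
    torch.tensor([[2]], dtype=torch.long),                  # stand-in for "<|eot_id|>"
], dim=1)

assert inputs_embeds.shape[1] == input_ids.shape[1]         # both sequences are 10 positions long
print(inputs_embeds.shape, input_ids.shape)
```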
 
  }
  """
 
+ # Gradio interface
  with gr.Blocks(theme="Hev832/Applio", css=css) as demo:
  with gr.Tab("Welcome"):
  gr.Markdown(
+ """
  <img src="https://path-to-yamamoto-logo.png" alt="Yamamoto Logo" class="centered-image">
 
  # 🎨 Yamamoto JoyCaption: AI-Powered Art Inspiration
 
  4. **Generate and Iterate**: Click 'Caption' to analyze your image and use the results to inspire new creations.
  """
  )
+
  with gr.Tab("JoyCaption"):
  with gr.Accordion("How to Use JoyCaption", open=False):
  gr.Markdown("""
 
  with gr.Row():
  with gr.Column():
+ input_image = gr.Image(type="pil", label="Input Image")
 
  caption_type = gr.Dropdown(
  choices=["descriptive", "training_prompt", "rng-tags", "style_prompt"],
+ label="Caption Type",
  value="descriptive",
  )
 
  caption_tone = gr.Dropdown(
  choices=["formal", "informal"],
+ label="Caption Tone",
  value="formal",
  )
 
  caption_length = gr.Dropdown(
  choices=["any", "very short", "short", "medium-length", "long", "very long"] +
  [str(i) for i in range(20, 261, 10)],
+ label="Caption Length",
  value="any",
  )
 
+ lens_type = gr.Dropdown(
+ choices=["Wide-angle", "Standard", "Telephoto", "Macro", "Fish-eye"],
+ label="Lens Type",
+ visible=False,
+ )
+
+ film_stock = gr.Dropdown(
+ choices=["Kodak Portra", "Fujifilm Velvia", "Ilford Delta", "Kodak Tri-X", "Fujifilm Provia"],
+ label="Film Stock",
+ visible=False,
+ )
+
+ composition_style = gr.Dropdown(
+ choices=["Rule of Thirds", "Golden Ratio", "Symmetry", "Leading Lines", "Framing"],
+ label="Composition Style",
+ visible=False,
+ )
+
+ gr.Markdown("**Note:** Caption tone doesn't affect `rng-tags`, `training_prompt`, and `style_prompt`.")
 
  run_button = gr.Button("Make My Caption!")
 
  with gr.Column():
+ output_caption = gr.Textbox(label="Generated Caption")
+ copy_button = gr.Button("Copy to Clipboard")
+
+ def update_style_options(caption_type):
+ return {
+ lens_type: gr.update(visible=caption_type == "style_prompt"),
+ film_stock: gr.update(visible=caption_type == "style_prompt"),
+ composition_style: gr.update(visible=caption_type == "style_prompt"),
+ }
+
+ caption_type.change(update_style_options, inputs=[caption_type], outputs=[lens_type, film_stock, composition_style])
+
+ run_button.click(fn=stream_chat, inputs=[input_image, caption_type, caption_tone, caption_length, lens_type, film_stock, composition_style], outputs=[output_caption])
 
+ def copy_to_clipboard():
+ return None
 
+ copy_button.click(fn=copy_to_clipboard, inputs=[], outputs=[])
 
  if __name__ == "__main__":
  demo.launch()
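The new UI drives the visibility of the three style-specific dropdowns from the selected caption type via gr.update. A stripped-down, standalone illustration of that wiring (single dropdown, trimmed choices; assumes a Gradio version where gr.update is available, as the committed code does):

```python
import gradio as gr

with gr.Blocks() as toy_demo:
    caption_type = gr.Dropdown(
        choices=["descriptive", "training_prompt", "rng-tags", "style_prompt"],
        value="descriptive",
        label="Caption Type",
    )
    # Hidden until "style_prompt" is selected, mirroring the commit's pattern
    lens_type = gr.Dropdown(choices=["Standard", "Macro"], label="Lens Type", visible=False)

    def toggle_style_options(choice):
        # gr.update changes a property of an existing component without recreating it
        return gr.update(visible=(choice == "style_prompt"))

    caption_type.change(toggle_style_options, inputs=[caption_type], outputs=[lens_type])

if __name__ == "__main__":
    toy_demo.launch()
```

Note that copy_to_clipboard as committed returns None and its click handler has no outputs, so the copy button appears to be a placeholder; an actual clipboard copy would need client-side JavaScript.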