K00B404 commited on
Commit
670b76a
·
verified ·
1 Parent(s): 26522e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -38
app.py CHANGED
@@ -1,45 +1,83 @@
1
  import gradio as gr
2
- from transformers import AutoProcessor, AutoModelForVision2Seq
3
  import torch
 
 
4
  from PIL import Image
 
5
 
6
- # Load NanoLLaVA model and processor
7
- model_name = "facebook/nano-llava"
8
- processor = AutoProcessor.from_pretrained(model_name)
9
- model = AutoModelForVision2Seq.from_pretrained(model_name)
10
 
11
- def generate_caption(image):
12
- # Process the image
13
- inputs = processor(images=image, text="Describe this image in detail", return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- # Generate caption
16
- outputs = model.generate(
17
- **inputs,
18
- max_length=100,
19
- num_beams=4,
20
- temperature=0.8
21
  )
22
 
23
- # Decode the caption
24
- caption = processor.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  return caption
26
 
27
  def create_persona(caption):
28
- # Template for transforming caption into a persona
29
- persona_prompt = f"""You are a character based on this description: {caption}
30
 
31
  Role: An entity exactly as described in the image
32
  Background: Your appearance and characteristics match the image description
33
  Personality: Reflect the mood, style, and elements captured in the image
34
  Goal: Interact authentically based on your visual characteristics
35
 
36
- Please stay in character and respond as this entity would, incorporating visual elements from your description into your responses."""
37
 
38
  return persona_prompt
39
 
40
- def process_image_to_persona(image):
 
 
 
 
 
 
41
  # Generate caption from image
42
- caption = generate_caption(image)
43
 
44
  # Transform caption into persona
45
  persona = create_persona(caption)
@@ -47,26 +85,33 @@ def process_image_to_persona(image):
47
  return caption, persona
48
 
49
  # Create Gradio interface
50
- with gr.Blocks() as app:
51
- gr.Markdown("# Image to Chatbot Persona Generator")
52
- gr.Markdown("Upload an image of a character to generate a persona for a chatbot based on the image.")
53
-
54
- with gr.Row():
55
- image_input = gr.Image(type="pil", label="Upload Character Image")
56
 
57
- with gr.Row():
58
- generate_button = gr.Button("Generate Persona")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- with gr.Row():
61
- caption_output = gr.Textbox(label="Generated Caption", lines=3)
62
- persona_output = gr.Textbox(label="Chatbot Persona", lines=10)
63
-
64
- generate_button.click(
65
- fn=process_image_to_persona,
66
- inputs=[image_input],
67
- outputs=[caption_output, persona_output]
68
- )
69
 
70
  # Launch the app
71
  if __name__ == "__main__":
 
72
  app.launch(share=True)
 
1
  import gradio as gr
 
2
  import torch
3
+ import transformers
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
  from PIL import Image
6
+ import warnings
7
 
8
+ # Disable warnings and progress bars
9
+ transformers.logging.set_verbosity_error()
10
+ transformers.logging.disable_progress_bar()
11
+ warnings.filterwarnings('ignore')
12
 
13
+ # Initialize model and tokenizer
14
+ def load_model(device='cpu'):
15
+ model = AutoModelForCausalLM.from_pretrained(
16
+ 'qnguyen3/nanoLLaVA',
17
+ torch_dtype=torch.float16,
18
+ device_map='auto',
19
+ trust_remote_code=True
20
+ )
21
+ tokenizer = AutoTokenizer.from_pretrained(
22
+ 'qnguyen3/nanoLLaVA',
23
+ trust_remote_code=True
24
+ )
25
+ return model, tokenizer
26
+
27
+ def generate_caption(image, model, tokenizer):
28
+ # Prepare the prompt
29
+ prompt = "Describe this image in detail"
30
+ messages = [
31
+ {"role": "system", "content": "Answer the question"},
32
+ {"role": "user", "content": f'<image>\n{prompt}'}
33
+ ]
34
 
35
+ # Apply chat template
36
+ text = tokenizer.apply_chat_template(
37
+ messages,
38
+ tokenize=False,
39
+ add_generation_prompt=True
 
40
  )
41
 
42
+ # Process text and image
43
+ text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
44
+ input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
45
+ image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)
46
+
47
+ # Generate caption
48
+ output_ids = model.generate(
49
+ input_ids,
50
+ images=image_tensor,
51
+ max_new_tokens=2048,
52
+ use_cache=True
53
+ )[0]
54
+
55
+ # Decode the output
56
+ caption = tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
57
  return caption
58
 
59
  def create_persona(caption):
60
+ persona_prompt = f"""<|im_start|>system
61
+ You are a character based on this description: {caption}
62
 
63
  Role: An entity exactly as described in the image
64
  Background: Your appearance and characteristics match the image description
65
  Personality: Reflect the mood, style, and elements captured in the image
66
  Goal: Interact authentically based on your visual characteristics
67
 
68
+ Please stay in character and respond as this entity would, incorporating visual elements from your description into your responses.<|im_end|>"""
69
 
70
  return persona_prompt
71
 
72
+ def process_image_to_persona(image, model, tokenizer):
73
+ if image is None:
74
+ return "Please upload an image.", ""
75
+ # Convert to PIL Image if needed
76
+ if not isinstance(image, Image.Image):
77
+ image = Image.fromarray(image)
78
+
79
  # Generate caption from image
80
+ caption = generate_caption(image, model, tokenizer)
81
 
82
  # Transform caption into persona
83
  persona = create_persona(caption)
 
85
  return caption, persona
86
 
87
  # Create Gradio interface
88
+ def create_interface():
89
+ # Load model and tokenizer
90
+ model, tokenizer = load_model()
 
 
 
91
 
92
+ with gr.Blocks() as app:
93
+ gr.Markdown("# Image to Chatbot Persona Generator")
94
+ gr.Markdown("Upload an image of a character to generate a persona for a chatbot based on the image.")
95
+
96
+ with gr.Row():
97
+ image_input = gr.Image(type="pil", label="Upload Character Image")
98
+
99
+ with gr.Row():
100
+ generate_button = gr.Button("Generate Persona")
101
+
102
+ with gr.Row():
103
+ caption_output = gr.Textbox(label="Generated Caption", lines=3)
104
+ persona_output = gr.Textbox(label="Chatbot Persona", lines=10)
105
+
106
+ generate_button.click(
107
+ fn=lambda img: process_image_to_persona(img, model, tokenizer),
108
+ inputs=[image_input],
109
+ outputs=[caption_output, persona_output]
110
+ )
111
 
112
+ return app
 
 
 
 
 
 
 
 
113
 
114
  # Launch the app
115
  if __name__ == "__main__":
116
+ app = create_interface()
117
  app.launch(share=True)