File size: 3,797 Bytes
26522e0
670b76a
2d1cfc3
670b76a
26522e0
670b76a
26522e0
670b76a
 
 
 
26522e0
670b76a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26522e0
670b76a
 
 
 
 
26522e0
 
670b76a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26522e0
 
 
670b76a
 
26522e0
 
 
 
 
 
670b76a
26522e0
 
 
670b76a
 
 
 
 
 
 
26522e0
670b76a
26522e0
 
 
 
 
 
 
670b76a
 
 
26522e0
670b76a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26522e0
670b76a
26522e0
 
 
670b76a
26522e0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr
import transformers
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import warnings

# Silence library chatter: suppress Python warnings and quiet the
# transformers logger and its download progress bars.
warnings.filterwarnings('ignore')
transformers.logging.set_verbosity_error()
transformers.logging.disable_progress_bar()

# Initialize model and tokenizer
def load_model(device='cpu'):
    """Load the nanoLLaVA model and its tokenizer from the Hugging Face Hub.

    Parameters
    ----------
    device : str
        Kept for backward compatibility. NOTE(review): this argument was
        never used by the original code — placement is delegated entirely
        to ``device_map='auto'``; confirm whether callers expect it to
        force CPU/GPU placement.

    Returns
    -------
    tuple
        ``(model, tokenizer)`` ready for caption generation.
    """
    # Single source of truth for the checkpoint id (was duplicated across
    # both from_pretrained calls).
    model_id = 'qnguyen3/nanoLLaVA'
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,   # half precision to halve memory use
        device_map='auto',           # let accelerate pick the device(s)
        trust_remote_code=True       # nanoLLaVA ships custom model code
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True
    )
    return model, tokenizer

def generate_caption(image, model, tokenizer,
                     prompt="Describe this image in detail",
                     max_new_tokens=2048):
    """Generate a text caption for ``image`` with a LLaVA-style model.

    Parameters
    ----------
    image : PIL.Image.Image
        The image to describe.
    model, tokenizer
        A model exposing ``process_images`` / ``generate`` and a tokenizer
        with a ChatML chat template (e.g. qnguyen3/nanoLLaVA).
    prompt : str, optional
        Instruction sent to the model (new parameter; default preserves
        the original hard-coded prompt).
    max_new_tokens : int, optional
        Generation budget (new parameter; default preserves original 2048).

    Returns
    -------
    str
        The decoded caption, stripped of special tokens and whitespace.
    """
    messages = [
        {"role": "system", "content": "Answer the question"},
        {"role": "user", "content": f'<image>\n{prompt}'}
    ]

    # Render the chat template WITHOUT tokenizing so we can split on the
    # literal '<image>' placeholder and splice the image token in by hand.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # -200 is the sentinel image-token id nanoLLaVA expects between the
    # two text chunks surrounding the '<image>' placeholder.
    text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
    input_ids = torch.tensor(
        text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long
    ).unsqueeze(0).to(model.device)  # fix: inputs must live on the model's device

    # fix: match the model's device as well as its dtype — the original
    # only converted dtype, which fails when device_map places the model
    # on a GPU.
    image_tensor = model.process_images([image], model.config).to(
        dtype=model.dtype, device=model.device
    )

    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        max_new_tokens=max_new_tokens,
        use_cache=True
    )[0]

    # Drop the prompt tokens; decode only what the model generated.
    caption = tokenizer.decode(
        output_ids[input_ids.shape[1]:], skip_special_tokens=True
    ).strip()
    return caption

def create_persona(caption):
    """Wrap an image caption in a ChatML system-prompt persona.

    ``caption`` is interpolated into a fixed character template; the
    returned string is the exact prompt consumed by downstream chat models.
    """
    template_lines = (
        "<|im_start|>system",
        f"You are a character based on this description: {caption}",
        "",
        "Role: An entity exactly as described in the image",
        "Background: Your appearance and characteristics match the image description",
        "Personality: Reflect the mood, style, and elements captured in the image",
        "Goal: Interact authentically based on your visual characteristics",
        "",
        "Please stay in character and respond as this entity would, incorporating visual elements from your description into your responses.<|im_end|>",
    )
    return "\n".join(template_lines)

def process_image_to_persona(image, model, tokenizer):
    """Caption an uploaded image and turn that caption into a chat persona.

    Returns a ``(caption, persona)`` pair, or a placeholder message pair
    when no image was supplied.
    """
    if image is None:
        return "Please upload an image.", ""

    # Gradio may hand us a raw numpy array; normalize to PIL first.
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)

    caption = generate_caption(pil_image, model, tokenizer)
    return caption, create_persona(caption)

# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI, wired to a freshly loaded model.

    Loads the model/tokenizer once at construction time and closes over
    them in the click handler, so the UI callback only receives the image.
    """
    model, tokenizer = load_model()

    with gr.Blocks() as demo:
        gr.Markdown("# Image to Chatbot Persona Generator")
        gr.Markdown("Upload an image of a character to generate a persona for a chatbot based on the image.")

        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload Character Image")

        with gr.Row():
            generate_button = gr.Button("Generate Persona")

        with gr.Row():
            caption_output = gr.Textbox(label="Generated Caption", lines=3)
            persona_output = gr.Textbox(label="Chatbot Persona", lines=10)

        # Named handler instead of an inline lambda; model/tokenizer are
        # captured from the enclosing scope.
        def _on_generate(img):
            return process_image_to_persona(img, model, tokenizer)

        generate_button.click(
            fn=_on_generate,
            inputs=[image_input],
            outputs=[caption_output, persona_output]
        )

    return demo

# Launch the app
if __name__ == "__main__":
    # Build the UI and expose it with a public share link.
    create_interface().launch(share=True)