import gradio as gr
from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from PIL import Image
import copy
import torch
import warnings

warnings.filterwarnings("ignore")

pretrained = "AI-Safeguard/Ivy-VL-llava"
model_name = "llava_qwen"
device = "cuda" if torch.cuda.is_available() else "cpu"
device_map = "auto"

# Load model, tokenizer, and image processor
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, attn_implementation="sdpa")
model.eval()
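# Note: if the full-precision weights do not fit in memory, most llava builds
# also accept load_4bit=True / load_8bit=True in load_pretrained_model.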

def respond(image, question, temperature, max_tokens):
    try:
        # Preprocess the uploaded image into model-ready tensors
        image_tensor = process_images([image], image_processor, model.config)
        # float16 matches the loaded weights on GPU; fall back to float32 on CPU
        dtype = torch.float16 if device == "cuda" else torch.float32
        image_tensor = [_image.to(dtype=dtype, device=device) for _image in image_tensor]

        # Prepare the conversation template
        conv_template = "qwen_1_5"
        formatted_question = DEFAULT_IMAGE_TOKEN + "\n" + question
        conv = copy.deepcopy(conv_templates[conv_template])
        conv.append_message(conv.roles[0], formatted_question)
        conv.append_message(conv.roles[1], None)
        prompt_question = conv.get_prompt()
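        # get_prompt() renders ChatML-style turns, roughly:
        #   <|im_start|>user\n<image>\n{question}<|im_end|>\n<|im_start|>assistant\n
        # (the exact system preamble depends on the installed llava version)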

        # Tokenize input
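        # tokenizer_image_token splits the prompt on "<image>" and splices in
        # IMAGE_TOKEN_INDEX (-200), which the model's forward pass replaces
        # with projected vision features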
        input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
        image_sizes = [image.size]  # original (width, height), used for any-res image handling

        # Generate response
        cont = model.generate(
            input_ids,
            images=image_tensor,
            image_sizes=image_sizes,
            do_sample=True,  # sampling must be enabled for temperature to take effect
            temperature=temperature,
            max_new_tokens=max_tokens,
        )

        text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
        return text_outputs[0]
    except Exception as e:
        return f"Error: {str(e)}"

# Gradio Interface
def chat_interface(image, question, temperature, max_tokens):
    if image is None or not question:
        return "Please provide both an image and a question."
    return respond(image, question, temperature, max_tokens)
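
# Quick smoke test without the UI (the image path is a placeholder):
#   img = Image.open("sample.jpg")
#   print(respond(img, "What is shown in this image?", 0.7, 128))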

demo = gr.Interface(
    fn=chat_interface,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(label="Question"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max New Tokens"),
    ],
    outputs="text",
    title="AI-Safeguard Ivy-VL-Llava Image Question Answering",
    description="Upload an image and ask a question about it. The model will provide a response based on the visual and textual input."
)
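
# demo.queue() before launch() serializes requests; with a single model
# instance on one device, concurrent generate() calls would otherwise
# contend for memory.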

if __name__ == "__main__":
    demo.launch()