import gradio as gr
from PIL import Image
import torch
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from huggingface_hub import login
import os

# Log in to the Hugging Face Hub only when the "access_token" environment
# variable holds a non-empty value; otherwise no login is attempted.
token = os.environ.get("access_token")
if token:
    login(
        token=token,
        add_to_git_credential=True,
    )

# Checkpoint identifier shared by the model weights and the processor.
model_id = "google/paligemma-3b-mix-224"

# Load the weights, then switch the module to evaluation mode.
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
model = model.eval()

# Processor loaded from the same checkpoint as the model.
processor = AutoProcessor.from_pretrained(model_id)

def generate_conversational_response(image, user_input):
    """Generate the model's text response to a prompt about an image.

    Args:
        image: A ``PIL.Image.Image``, or any value ``PIL.Image.open``
            accepts (e.g. a file path or a binary file object).
        user_input: The prompt to condition generation on. ``None`` is
            treated as an empty prompt.

    Returns:
        The decoded text of the newly generated tokens only (the prompt
        tokens are sliced off), with special tokens removed.

    Raises:
        gr.Error: If ``image`` is ``None``, so the UI can show a readable
            message instead of an opaque failure.
    """
    # Gradio passes None when the form is submitted without an image; fail
    # early with a user-facing message rather than crashing in Image.open.
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    if isinstance(image, Image.Image):
        # Normalise palette/alpha/greyscale images to three channels.
        # NOTE(review): assumes the processor expects RGB input -- confirm.
        if image.mode != "RGB":
            image = image.convert("RGB")
    else:
        # convert() loads the pixel data and returns a new image, so the
        # underlying file can be closed as soon as the block exits.
        with Image.open(image) as opened:
            image = opened.convert("RGB")

    # Avoid feeding the literal text "None" to the model.
    prompt = "" if user_input is None else str(user_input)

    # Tokenise the prompt, preprocess the image, and place the tensors on
    # the same device as the model.
    model_inputs = processor(text=prompt, images=image, return_tensors="pt")
    model_inputs = model_inputs.to(model.device)
    input_len = model_inputs["input_ids"].shape[-1]

    # Greedy decoding (do_sample=False) without autograd bookkeeping.
    with torch.inference_mode():
        generation = model.generate(
            **model_inputs, max_new_tokens=1024, do_sample=False
        )
        # The output sequence starts with the prompt tokens; keep only the
        # continuation.
        new_tokens = generation[0][input_len:]
        decoded = processor.decode(new_tokens, skip_special_tokens=True)

    return decoded

# Input widgets: an image upload delivered to the callback as a PIL image,
# and a two-line textbox for the prompt.
image_input = gr.Image(type="pil", label="Upload Image")
text_input = gr.Textbox(
    lines=2,
    placeholder="Enter your question or starting input here",
    label="Starting Input",
)

# Wire the widgets to the generation function; the result is shown as text.
interface = gr.Interface(
    fn=generate_conversational_response,
    inputs=[image_input, text_input],
    outputs="text",
    title="Image-Based Conversational AI",
    description="Upload an image from your local system and provide a starting input. The model will generate a caption and respond to your query based on the image.",
)

# Start the web UI.
interface.launch()