import gradio as gr
from PIL import Image
import torch
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from huggingface_hub import login
import os
# Retrieve the API token from the Space's environment variables and log in.
# PaliGemma is a gated model, so a valid token is needed to download its weights.
token = os.getenv("access_token")
if token:
    login(token=token, add_to_git_credential=True)
# Load model and processor
model_id = "google/paligemma-3b-mix-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained(model_id)
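# Optional: run on a GPU when the Space provides one (a minimal sketch; on
# CPU-only hardware the model simply stays on the CPU).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)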
def generate_conversational_response(image, user_input):
    # Guard against submissions without an image
    if image is None:
        return "Please upload an image first."
    # Ensure the image is in PIL format
    if not isinstance(image, Image.Image):
        image = Image.open(image)
    # Use the user's input as the prompt; the "mix" PaliGemma checkpoints expect
    # task-prefixed prompts such as "caption en" or "answer en <question>"
    prompt = user_input
    # Process the image and text prompt, keeping tensors on the model's device
    model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
    input_len = model_inputs["input_ids"].shape[-1]
    # Generate the response greedily
    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=1024, do_sample=False)
    # Drop the prompt tokens and decode only the newly generated text
    generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    return decoded
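# Illustrative local check (assumed usage; "example.jpg" is a placeholder path):
# print(generate_conversational_response(Image.open("example.jpg"), "answer en What is shown in this image?"))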
# Set up Gradio interface
interface = gr.Interface(
    fn=generate_conversational_response,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),  # Allows users to upload local images
        gr.Textbox(lines=2, placeholder="Enter your question or starting input here", label="Starting Input")
    ],
    outputs="text",
    title="Image-Based Conversational AI",
    description="Upload an image from your local system and provide a starting input. The model will generate a caption and respond to your query based on the image."
)
# Launch the interface
interface.launch()