import gradio as gr
import torch
from PIL import Image
from transformers import (
    BitsAndBytesConfig,
    PaliGemmaForConditionalGeneration,
    PaliGemmaProcessor,
)


def load_model():
    """Load the model and processor."""
    repo_name = "ighoshsubho/pali-gamma-finetuned-json"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Configure 4-bit NF4 quantization to keep GPU memory usage low.
    # bitsandbytes 4-bit loading requires CUDA, so skip it on CPU.
    quantization_config = None
    if device == "cuda":
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )

    # Load processor and model
    processor = PaliGemmaProcessor.from_pretrained(repo_name)
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        repo_name,
        quantization_config=quantization_config,
        device_map=device,
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
    )

    return model, processor


# Load the model once at startup so every request reuses it
print("Loading model...")
model, processor = load_model()
print("Model loaded successfully!")


def process_image(image, prompt):
    """Run the model on an image/prompt pair and return the decoded text."""
    try:
        # Ensure the image is a PIL image (Gradio may hand us a file path)
        if not isinstance(image, Image.Image):
            image = Image.open(image)

        # Prepare inputs
        inputs = processor(
            text=[prompt],
            images=[image],
            return_tensors="pt",
            padding="longest",
        ).to(model.device)

        # Generate output. Beam search is deterministic, so the temperature
        # setting is dropped; it only takes effect when do_sample=True.
        # max_new_tokens bounds the generated text without counting the
        # image/prompt tokens, unlike max_length.
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            num_beams=5,
        )

        # Decode only the newly generated tokens, skipping the echoed prompt
        generated = outputs[0][inputs["input_ids"].shape[1]:]
        result = processor.decode(generated, skip_special_tokens=True)

        return result
    except Exception as e:
        return f"Error processing image: {e}"


# Create the Gradio interface
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt here...",
            value="extract data in JSON format",
        ),
    ],
    outputs=gr.Textbox(label="Generated Output"),
    title="PaliGemma Image Analysis",
    description=(
        "Upload an image and get structured data extracted in JSON format. "
        "On GPU the model runs in 4-bit quantization mode."
    ),
)

if __name__ == "__main__":
    demo.launch()
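
# A minimal smoke test, shown as a hedged sketch rather than part of the app:
# importing this module and calling process_image() directly exercises the
# model without starting the Gradio UI. "sample.png" is a hypothetical path;
# substitute any local image.
#
#     from PIL import Image
#     img = Image.open("sample.png")
#     print(process_image(img, "extract data in JSON format"))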