"""Gradio app for image captioning and top-5 image classification.

Uses a ViT-GPT2 vision-encoder/decoder for captions and a ViT classifier
(ImageNet labels) for classification. Builds a Gradio Blocks UI with one
image input and two action buttons.
"""

import gradio as gr
import torch
from transformers import (
    AutoImageProcessor,
    AutoTokenizer,
    ViTForImageClassification,
    VisionEncoderDecoderModel,
)

# Run inference on GPU when available; inputs must be moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model checkpoints
caption_model_ckpt = "nlpconnect/vit-gpt2-image-captioning"
classify_model_ckpt = "google/vit-base-patch16-224"

# Load captioning components
tokenizer = AutoTokenizer.from_pretrained(caption_model_ckpt)
image_processor = AutoImageProcessor.from_pretrained(caption_model_ckpt)
caption_model = VisionEncoderDecoderModel.from_pretrained(caption_model_ckpt).to(device)

# Load classification model
classify_processor = AutoImageProcessor.from_pretrained(classify_model_ckpt)
classification_model = ViTForImageClassification.from_pretrained(classify_model_ckpt).to(device)


def get_caption(image):
    """Return a natural-language caption for *image*.

    Args:
        image: a PIL image from the Gradio input, or None if nothing
            was uploaded.

    Returns:
        The decoded caption string, or a placeholder message when no
        image was provided.
    """
    if image is None:
        return "No image uploaded."
    image = image.convert("RGB")  # the vision encoder expects 3-channel input
    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values.to(device)
    # Inference only: no_grad avoids building autograd graphs, which
    # would otherwise waste memory on every request.
    with torch.no_grad():
        output_ids = caption_model.generate(pixel_values, max_length=64, num_beams=4)[0]
    caption = tokenizer.decode(output_ids, skip_special_tokens=True)
    return caption


def classify_image(image):
    """Return the top-5 class predictions for *image*.

    Args:
        image: a PIL image from the Gradio input, or None if nothing
            was uploaded.

    Returns:
        A dict mapping class label -> probability (rounded to 4 places),
        suitable for a gr.Label component; an error dict when no image
        was provided.
    """
    if image is None:
        return {"Error": "No image uploaded."}
    image = image.convert("RGB")
    inputs = classify_processor(images=image, return_tensors="pt").to(device)
    # Inference only — skip gradient tracking.
    with torch.no_grad():
        outputs = classification_model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    top_probs, top_labels = torch.topk(probs, 5)
    results = {
        classification_model.config.id2label[label.item()]: round(prob.item(), 4)
        for label, prob in zip(top_labels[0], top_probs[0])
    }
    return results


# Gradio app
with gr.Blocks(title="Image Captioning and Recognition") as demo:
    gr.Markdown("# 🖼️ Image Captioning & Classification App")
    gr.Markdown("Upload an image, then click below to generate a caption or classify it.")
    image_input = gr.Image(label="Upload Image", type="pil")
    with gr.Row():
        get_caption_btn = gr.Button("📝 Get Caption")
        classify_btn = gr.Button("🔍 Classify Image")
    caption_output = gr.Textbox(label="Generated Caption")
    classification_output = gr.Label(label="Top 5 Predictions")
    get_caption_btn.click(fn=get_caption, inputs=image_input, outputs=caption_output)
    classify_btn.click(fn=classify_image, inputs=image_input, outputs=classification_output)

demo.launch()