"""Gradio demo that classifies an uploaded image with a pre-trained ViT model."""

from typing import Optional

import gradio as gr
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

# Load pre-trained model and feature extractor once at import time so that
# every request reuses the same objects instead of reloading them.
model_name = "google/vit-base-patch16-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)


def classify_image(image: Optional[Image.Image]) -> str:
    """Return the label the model assigns to *image*.

    Args:
        image: The image to classify. The Gradio input below is configured
            with ``type="pil"``, so this is normally a PIL image; ``None`` is
            rejected explicitly (presumably what Gradio passes when nothing
            was uploaded -- confirm against the Gradio version in use).

    Returns:
        The entry of ``model.config.id2label`` for the highest-scoring logit.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Fail with a readable message instead of an opaque error raised from
        # inside the preprocessor.
        raise gr.Error("Please upload an image before submitting.")

    # Normalise PIL inputs to three channels so RGBA / grayscale / palette
    # images passed directly to this function are handled too. Non-PIL inputs
    # (e.g. arrays) are forwarded to the preprocessor untouched.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess the image into PyTorch tensors.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Map the index of the largest logit to its label.
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


# Create Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Predicted Class"),
    title="Image Classification",
    description="Upload an image to classify it using a pre-trained ViT model.",
)

# Launch the app. Kept at module level (not behind a __main__ guard) so the
# script behaves the same whether it is executed directly or imported.
iface.launch()