# File size: 1,082 Bytes
# Revision: 5da8948
import gradio as gr
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
# Fetch the preprocessing pipeline and the classifier from the same Hugging Face
# Hub checkpoint; they must match so images are normalised the way the model
# expects. Both are loaded once at import time and shared by every request.
_CHECKPOINT = "google/vit-base-patch16-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
model = AutoModelForImageClassification.from_pretrained(_CHECKPOINT)
# Define the function to classify images
def classify_image(image):
image = Image.fromarray(image).convert("RGB") # Convert input image to RGB
inputs = feature_extractor(images=image, return_tensors="pt") # Preprocess image
outputs = model(**inputs) # Get model predictions
predicted_class_idx = outputs.logits.argmax(-1).item() # Get predicted class index
return model.config.id2label[predicted_class_idx] # Return class label
# Gradio UI: one image upload (handed to the callback as a NumPy array) mapped
# to a plain-text predicted label.
app = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="numpy"),
    outputs="text",
    title="Image Classification App",
)

# Start the web server only when executed as a script, so importing this
# module (tests, other apps, Gradio reload mode) does not block on launch().
if __name__ == "__main__":
    app.launch()