"""Gradio demo that classifies an uploaded image with a pretrained ResNet-18.

The model and its image processor are loaded once at import time; the web UI
shows the top-1 label and a bar chart of the most likely classes.
"""

import logging

import gradio as gr
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from transformers import (
    AutoFeatureExtractor,
    AutoImageProcessor,
    AutoModelForImageClassification,
)

logger = logging.getLogger(__name__)

# Use a smaller, more efficient model
model_name = "microsoft/resnet-18"  # Smaller model that should work with Hugging Face constraints

# Number of classes shown in the confidence chart.
TOP_K = 5

# Load model and image processor.
# AutoImageProcessor replaces the deprecated vision AutoFeatureExtractor; the
# module-level name `feature_extractor` is kept so existing references still work.
feature_extractor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)


def _plot_top_predictions(classes, probabilities):
    """Return a horizontal bar chart of *probabilities* (percent) labelled by *classes*.

    Bars are placed at integer positions rather than categorical string ticks,
    so duplicate label strings each get their own bar instead of being merged.
    The first (most likely) class is drawn at the top.
    """
    # Build the figure without pyplot: pyplot keeps every figure alive in a
    # global registry (a leak in a long-running server) and its "current
    # figure" state is shared between Gradio's worker threads.
    fig = Figure(figsize=(10, 5))
    ax = fig.subplots()

    positions = list(range(len(classes)))
    bars = ax.barh(positions, probabilities, color='#4C72B0')
    ax.set_yticks(positions)
    ax.set_yticklabels(classes)
    ax.invert_yaxis()  # highest-probability class first, at the top
    ax.set_xlabel('Probability (%)')
    ax.set_title(f'Top {len(classes)} Predictions')
    # Leave room right of the bars so a label next to a ~100% bar is not clipped.
    ax.set_xlim(0, 110)

    # Add percentage labels just past the end of each bar.
    for bar, prob in zip(bars, probabilities):
        ax.text(
            bar.get_width() + 1,
            bar.get_y() + bar.get_height() / 2,
            f'{prob:.1f}%',
            va='center',
            fontsize=10,
        )

    fig.tight_layout()
    return fig


def classify_image(image):
    """Classify a PIL image with the pretrained model.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        ``(label, figure)``: the top-1 class name and a matplotlib figure of
        the top predictions. Returns ``("No image provided", None)`` for a
        missing image and ``("Error: <message>", None)`` if classification
        fails, so the UI shows the problem instead of crashing.
    """
    if image is None:
        return "No image provided", None

    try:
        # Uploads may be RGBA (PNG with alpha), palette or greyscale images;
        # the processor normalises three channels, so force RGB first.
        image = image.convert("RGB")
        inputs = feature_extractor(images=image, return_tensors="pt")

        # Make prediction
        with torch.no_grad():
            logits = model(**inputs).logits

        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        # Clamp so models with fewer than TOP_K labels do not make topk raise.
        k = min(TOP_K, probs.numel())
        top_probs, top_indices = torch.topk(probs, k)

        classes = [model.config.id2label[idx.item()] for idx in top_indices]
        probabilities = [prob.item() * 100 for prob in top_probs]

        # topk returns results sorted in descending order, so the first entry
        # is the highest-probability class.
        predicted_class = classes[0]
        return predicted_class, _plot_top_predictions(classes, probabilities)
    except Exception as e:
        # UI boundary: keep the full traceback in the server log, show the
        # user a short message instead of failing the request.
        logger.exception("Image classification failed")
        return f"Error: {e}", None


# Create Gradio interface with simpler structure
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Plot(label="Confidence Levels"),
    ],
    title="🖼️ Image Classification Tool",
    description="Upload an image to see what the AI recognizes in it!",
    # NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour
    # of `flagging_mode`; confirm against the pinned Gradio version.
    allow_flagging="never",
    examples=[],  # No examples to avoid dependencies
    theme=gr.themes.Soft(),
)

# Launch the app
if __name__ == "__main__":
    demo.launch()