# Spaces: Sleeping
import subprocess | |
import sys | |
import os | |
# Install (or force-reinstall) one pip requirement into the running interpreter.
def install(package):
    """Force-reinstall ``package`` with pip under the current interpreter.

    ``package`` is handed to pip verbatim, so it may carry a version pin such
    as ``"numpy==1.24.3"``. ``subprocess.check_call`` raises
    ``subprocess.CalledProcessError`` if pip exits with a non-zero status.
    """
    # NOTE(review): --force-reinstall is applied without --no-deps, so pip may
    # also reinstall the package's dependencies -- confirm this is intended.
    pip_command = [sys.executable, "-m", "pip", "install", "--force-reinstall"]
    pip_command.append(package)
    subprocess.check_call(pip_command)
# Dependency bootstrap: each section tries the import and falls back to a
# pinned pip install when the import fails.

# NumPy must be a 1.24.x release. pandas is imported only as an availability
# probe: whichever of the two imports fails, NumPy alone is (re)installed.
try:
    import numpy as np
    import pandas as pd
    if np.__version__[:4] != "1.24":
        print("Detected incompatible versions. Reinstalling NumPy...")
        install("numpy==1.24.3")
except ImportError:
    print("NumPy or pandas not found. Installing compatible versions...")
    install("numpy==1.24.3")

# torch and torchvision are handled as a pair: if either import fails, both
# pinned versions are installed.
try:
    import torch
    import torchvision
except ImportError:
    install("torch==2.0.1")
    install("torchvision==0.15.2")

# Pillow provides the PIL package.
try:
    from PIL import Image
except ImportError:
    install("Pillow==9.5.0")

# Gradio serves the web UI.
try:
    import gradio as gr
except ImportError:
    install("gradio==3.50.2")
# Import libraries after ensuring installations | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision.transforms as transforms | |
from PIL import Image | |
import gradio as gr | |
# Define the model
class ModifiedLargeNet(nn.Module):
    """Small CNN mapping a 3-channel image to three class logits.

    Layers: conv(3->5, k=5) -> ReLU -> 2x2 max-pool -> conv(5->10, k=5) ->
    ReLU -> 2x2 max-pool -> linear(10*29*29 -> 32) -> ReLU -> linear(32 -> 3).
    The first linear layer fixes the feature map at 10x29x29, which is what
    the conv/pool stack yields for 128x128 inputs (128 -> 124 -> 62 -> 58 -> 29).
    """

    # (channels, height, width) the conv/pool stack must produce for fc1.
    _FEATURE_SHAPE = (10, 29, 29)

    def __init__(self):
        super().__init__()
        self.name = "modified_large"
        self.conv1 = nn.Conv2d(3, 5, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(5, 10, 5)
        self.fc1 = nn.Linear(10 * 29 * 29, 32)
        self.fc2 = nn.Linear(32, 3)  # classify into "Rope"/"Hammer"/"others"

    def forward(self, x):
        """Return raw (pre-softmax) logits, one row of 3 values per image.

        Raises:
            RuntimeError: if the input's spatial size does not lead to a
                10x29x29 feature map (i.e. the image is not 128x128).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # view(-1, N) alone would silently regroup elements across the batch
        # whenever the total element count happens to be a multiple of N
        # (e.g. a single 244x244 image would come out as 4 "samples"), so the
        # per-sample feature shape is checked explicitly first.
        if tuple(x.shape[-3:]) != self._FEATURE_SHAPE:
            raise RuntimeError(
                f"Expected a feature map of shape {self._FEATURE_SHAPE} per image, "
                f"got {tuple(x.shape[-3:])}; inputs must be 128x128."
            )
        x = x.view(-1, 10 * 29 * 29)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
# Load the trained model onto the CPU and switch it to inference mode.
model = ModifiedLargeNet()
# torch.load unpickles the file; weights_only=True restricts unpickling to
# tensors and plain containers, which is all a state_dict needs, so a
# tampered checkpoint cannot execute arbitrary code on load.
state_dict = torch.load(
    "modified_large_net.pt",
    map_location=torch.device("cpu"),
    weights_only=True,
)
model.load_state_dict(state_dict)
model.eval()
# Preprocessing applied to every image before it reaches the model:
# resize to 128x128, convert to a tensor, then normalize each of the three
# channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
# Prediction function
def predict(image):
    """Classify an image and return one probability per class label.

    Args:
        image: A ``PIL.Image.Image``, or any object ``Image.fromarray``
            accepts (e.g. the array Gradio passes in).

    Returns:
        dict mapping "Rope", "Hammer" and "Other" to float probabilities
        (softmax over the model's three outputs).

    Raises:
        ValueError: if ``image`` is None, cannot be converted to a PIL image,
            or preprocessing/inference fails. The original exception is
            attached as ``__cause__``.
    """
    if image is None:
        raise ValueError("Please provide an image")

    # Convert to PIL Image if necessary
    if not isinstance(image, Image.Image):
        try:
            image = Image.fromarray(image)
        except Exception as e:
            raise ValueError(f"Failed to convert input to PIL Image: {e}") from e

    # Transform and predict
    try:
        # The Normalize step and the first conv layer both assume exactly
        # three channels; grayscale, RGBA or palette images would fail there.
        rgb_image = image.convert("RGB")
        batch = transform(rgb_image).unsqueeze(0)  # Add batch dimension
        with torch.no_grad():
            outputs = model(batch)
        # tolist() yields plain Python floats without a NumPy round trip.
        probabilities = torch.softmax(outputs, dim=1)[0].tolist()
        classes = ["Rope", "Hammer", "Other"]
        return dict(zip(classes, probabilities))
    except Exception as e:
        raise ValueError(f"Error during prediction: {e}") from e
# Gradio interface
# Offer only the example images that are actually present. Checking a single
# file and then listing both would hand Gradio a path that may not exist.
example_files = ["example_rope.jpg", "example_hammer.jpg"]
available_examples = [[path] for path in example_files if os.path.exists(path)]

interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(),  # no type="pil" constraint; predict() converts non-PIL input itself
    outputs=gr.Label(num_top_classes=3),
    title="Mechanical Tools Classifier",
    description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
    examples=available_examples or None,  # Optional examples
)

# Launch the interface
if __name__ == "__main__":
    interface.launch()  # share=True is not used on Hugging Face Spaces