# Assignment1 / app.py
# StoneSeller's picture
# Update app.py
# d06517d verified
# raw
# history blame
# 3.55 kB
import subprocess
import sys
import os
# Function to install or reinstall specific packages
def install(package):
    """Force-(re)install *package* with pip for the running interpreter.

    Args:
        package: A pip requirement specifier, e.g. ``"numpy==1.24.3"``.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    import importlib

    # Invoke pip through ``sys.executable`` so the package is installed for
    # the same interpreter that is executing this script.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--force-reinstall", package]
    )
    # Callers import the package right after installing it.  The import
    # system caches directory listings, so drop those caches to make sure the
    # freshly written files are seen.
    importlib.invalidate_caches()
# Ensure NumPy and pandas are compatible
_NUMPY_REQUIREMENT = "numpy==1.24.3"


def _installed_numpy_version():
    """Return the installed NumPy version string, or None if it is not installed.

    The version comes from package metadata, so NumPy itself is not imported
    by this check.
    """
    from importlib import metadata  # standard library since Python 3.8

    try:
        return metadata.version("numpy")
    except metadata.PackageNotFoundError:
        return None


# Check the version BEFORE importing NumPy.  Once a module has been imported
# it stays loaded for the lifetime of the process, so reinstalling it
# afterwards would not change the version this run actually uses.
_numpy_version = _installed_numpy_version()
if _numpy_version is not None and not _numpy_version.startswith("1.24"):
    print("Detected incompatible versions. Reinstalling NumPy...")
    install(_NUMPY_REQUIREMENT)

try:
    import numpy as np
    import pandas as pd
except ImportError:
    # Same fallback as before: a missing NumPy or pandas triggers the
    # pinned NumPy install.
    print("NumPy or pandas not found. Installing compatible versions...")
    install(_NUMPY_REQUIREMENT)
# Ensure other dependencies are installed with specific versions
_TORCH_PINS = ("torch==2.0.1", "torchvision==0.15.2")
_PILLOW_PIN = "Pillow==9.5.0"
_GRADIO_PIN = "gradio==3.50.2"

# torch and torchvision are probed together; if either import fails, both
# pinned packages are installed, torch first.
try:
    import torch
    import torchvision
except ImportError:
    for _requirement in _TORCH_PINS:
        install(_requirement)

# Pillow provides the ``PIL`` package.
try:
    from PIL import Image
except ImportError:
    install(_PILLOW_PIN)

try:
    import gradio as gr
except ImportError:
    install(_GRADIO_PIN)
# Import libraries after ensuring installations
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
# Define the model
class ModifiedLargeNet(nn.Module):
    """Small CNN that maps an image to three class logits.

    Two ``Conv2d (5x5) -> ReLU -> MaxPool2d(2, 2)`` stages are followed by two
    fully connected layers.  ``fc1`` takes ``10 * 29 * 29`` features per
    sample, which is what the convolutional stages produce for a
    ``3 x 128 x 128`` input (128 -> 124 -> 62 -> 58 -> 29 per side).
    """

    def __init__(self):
        super().__init__()
        self.name = "modified_large"
        self.conv1 = nn.Conv2d(3, 5, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(5, 10, 5)
        self.fc1 = nn.Linear(10 * 29 * 29, 32)
        self.fc2 = nn.Linear(32, 3)  # classify into "Rope"/"Hammer"/"others"

    def forward(self, x):
        """Return the raw class scores (logits) for ``x``.

        Args:
            x: Image tensor of shape ``(N, 3, 128, 128)``; a single unbatched
                ``(3, 128, 128)`` image is also accepted and treated as a
                batch of one.

        Returns:
            Tensor of shape ``(N, 3)``.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not yield
                ``10 * 29 * 29`` features per sample.
        """
        if x.dim() == 3:
            # Unbatched image: add the batch dimension explicitly.
            x = x.unsqueeze(0)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample.  ``view(-1, 10 * 29 * 29)`` would silently turn a
        # mis-sized image into several fake samples; keeping the batch
        # dimension fixed makes ``fc1`` reject the wrong size instead.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# Load the trained model
model = ModifiedLargeNet()
# Read the saved weights from disk, mapping every tensor onto the CPU.
_cpu_device = torch.device("cpu")
_saved_weights = torch.load("modified_large_net.pt", map_location=_cpu_device)
model.load_state_dict(_saved_weights)
# Put the module into evaluation mode before it is used for prediction.
model.eval()
# Define image transformation pipeline
_CHANNEL_MEAN = [0.5, 0.5, 0.5]
_CHANNEL_STD = [0.5, 0.5, 0.5]

# Applied in order: resize to 128x128, convert to a tensor, then normalize
# each of the three channels with the mean/std above.
_preprocessing_steps = [
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = transforms.Compose(_preprocessing_steps)
# Prediction function
def predict(image):
    """Classify *image* and return one probability per class label.

    Args:
        image: A ``PIL.Image.Image``, or any object that
            ``Image.fromarray`` can convert into one.

    Returns:
        dict mapping ``"Rope"``, ``"Hammer"`` and ``"Other"`` to the softmax
        probability (a Python float) of that class.

    Raises:
        ValueError: If ``image`` is None, cannot be converted to a PIL image,
            or preprocessing/inference fails.  The underlying exception is
            attached as ``__cause__``.
    """
    if image is None:
        raise ValueError("Please provide an image")

    # Convert to PIL Image if necessary
    if not isinstance(image, Image.Image):
        try:
            image = Image.fromarray(image)
        except Exception as e:
            raise ValueError(f"Failed to convert input to PIL Image: {str(e)}") from e

    # Transform and predict
    try:
        # The Normalize step and the first convolution are defined for three
        # channels, so grayscale, palette or RGBA input is converted to RGB.
        batch = transform(image.convert("RGB")).unsqueeze(0)  # Add batch dimension
        with torch.no_grad():
            outputs = model(batch)
        probabilities = torch.softmax(outputs, dim=1)[0].tolist()
    except Exception as e:
        raise ValueError(f"Error during prediction: {str(e)}") from e

    classes = ("Rope", "Hammer", "Other")
    return dict(zip(classes, probabilities))
# Gradio interface
# Optional examples: list only the example images that actually exist next to
# the app, so a missing file is left out instead of being referenced.
_EXAMPLE_FILES = ("example_rope.jpg", "example_hammer.jpg")
_available_examples = [[path] for path in _EXAMPLE_FILES if os.path.exists(path)]

interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(),  # no type= given; predict() converts non-PIL input itself
    outputs=gr.Label(num_top_classes=3),
    title="Mechanical Tools Classifier",
    description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
    examples=_available_examples or None,  # None when no example file is present
)

# Launch the interface
if __name__ == "__main__":
    interface.launch()  # Removed share=True for Hugging Face Spaces