Spaces:
Sleeping
Sleeping
File size: 4,694 Bytes
9f22b53 d06517d 9f22b53 9b05987 4348a2e 9b05987 4348a2e 970edfa 9f22b53 62fa3ff 970edfa 757e51c 7550275 970edfa 757e51c 7550275 970edfa 4348a2e 970edfa 757e51c 9c80e2d 6787f69 970edfa 9f22b53 4348a2e 9f22b53 9c80e2d 9f22b53 8342f3a 970edfa 9f22b53 8342f3a 9f22b53 9c80e2d 9f22b53 d16040b 9f22b53 d16040b 9c80e2d 9f22b53 970edfa d06517d 970edfa 9c80e2d d16040b 9c80e2d 7545949 9c80e2d d16040b 9c80e2d 7545949 9c80e2d 9f22b53 970edfa d06517d d16040b 9c80e2d 970edfa d16040b 7545949 d16040b 7545949 970edfa d16040b d06517d 970edfa 9c80e2d 970edfa 9c80e2d 970edfa d16040b 970edfa 9c80e2d 970edfa d06517d 970edfa 9c80e2d 970edfa 9f22b53 d16040b 9f22b53 7545949 d06517d 9f22b53 d16040b 9f22b53 970edfa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import subprocess
import sys
import os
# Function to install or reinstall specific packages
def install(package):
    """Install *package* into the interpreter running this script.

    *package* is a pip requirement string such as ``"numpy==1.24.3"``. pip is
    invoked as ``sys.executable -m pip`` so the package lands in this
    interpreter's environment, and with ``--force-reinstall`` so it is
    reinstalled even if it is already present.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    import importlib

    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--force-reinstall", package]
    )
    # The import system caches directory listings of sys.path entries; clear
    # those caches so modules pip just installed can be imported in this
    # same process.
    importlib.invalidate_caches()
# First, ensure NumPy is installed with the correct version.
# NOTE: when an incompatible NumPy is already installed, the probe import
# below loads it into this process. Reinstalling afterwards does not replace
# the already-imported module, so later `import numpy` statements in this run
# still get the old version; the pinned version only applies after a restart.
try:
    import numpy as np
except ImportError:
    print("NumPy not found. Installing...")
    install("numpy==1.24.3")
else:
    if not np.__version__.startswith("1.24"):
        print("Installing compatible NumPy version...")
        install("numpy==1.24.3")
        print(
            f"Warning: NumPy {np.__version__} is already loaded in this process; "
            "restart the app for numpy==1.24.3 to take effect."
        )
# Then install other dependencies.
# Maps pip distribution name -> pinned version to install when the package
# cannot be imported. Only presence is checked, not the installed version.
packages = {
    "torch": "2.0.1",
    "torchvision": "0.15.2",
    "Pillow": "9.5.0",
    "gradio": "3.50.2",
}
# Import names that differ from the lowercased distribution name. Pillow is
# imported as `PIL`; probing `pillow` would always fail and force a Pillow
# reinstall on every start-up.
_IMPORT_NAMES = {"Pillow": "PIL"}
for package, version in packages.items():
    module_name = _IMPORT_NAMES.get(package, package.lower())
    try:
        __import__(module_name)
    except ImportError:
        print(f"Installing {package}...")
        install(f"{package}=={version}")
import traceback
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
# Define the model exactly as in training
class ModifiedLargeNet(nn.Module):
    """Two conv/pool stages followed by two linear layers, emitting 3 logits.

    The submodule attribute names (conv1, pool, conv2, fc1, fc2) must not
    change: they become the state_dict keys matched by `load_state_dict`.
    """

    # Flattened feature count after the second conv/pool stage. For a
    # 128x128 input: 128 -conv5-> 124 -pool2-> 62 -conv5-> 58 -pool2-> 29,
    # with 10 output channels from conv2.
    _FLAT_FEATURES = 10 * 29 * 29

    def __init__(self):
        super().__init__()
        self.name = "modified_large"
        # Creation order is kept identical so parameter initialisation
        # consumes the RNG in the same sequence.
        self.conv1 = nn.Conv2d(3, 5, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(5, 10, 5)
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 32)
        self.fc2 = nn.Linear(32, 3)

    def forward(self, x):
        """Return unnormalised class scores, one row of 3 per sample."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(-1, self._FLAT_FEATURES)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)
# Load the trained model with error handling.
# If anything fails, `model` is set to None instead of being left with
# randomly initialised weights, so a failed load cannot be mistaken for a
# working model serving real predictions.
try:
    model = ModifiedLargeNet()
    # torch.load unpickles the checkpoint; only load files from trusted sources.
    state_dict = torch.load("modified_large_net.pt", map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    print("Model loaded successfully")
    model.eval()
except Exception as e:
    print(f"Error loading model: {str(e)}")
    traceback.print_exc()
    model = None
# Define image transformation pipeline.
# Per-channel RGB normalisation statistics.
# NOTE(review): these look like the common ImageNet mean/std values; confirm
# they match the preprocessing used when the model was trained.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        # 128x128 is the input size ModifiedLargeNet's 10 * 29 * 29 flatten assumes.
        transforms.Resize((128, 128)),
        # PIL image -> integer tensor -> float32 tensor (used instead of ToTensor()).
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float32),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
def process_image(image):
    """Coerce a Gradio image input into an RGB image object.

    Args:
        image: None, a numpy array (cast with ``astype('uint8')`` and wrapped
            via ``Image.fromarray``), or an object assumed to already be a PIL
            image.

    Returns:
        The image in 'RGB' mode (converted only when its mode differs), or
        None when *image* is None or any step raises (the error and
        traceback are printed).
    """
    if image is None:
        return None
    try:
        pil_image = (
            Image.fromarray(image.astype('uint8'))
            if isinstance(image, np.ndarray)
            else image
        )
        rgb = pil_image if pil_image.mode == 'RGB' else pil_image.convert('RGB')
        print(f"Processed image size: {rgb.size}")
        print(f"Processed image mode: {rgb.mode}")
        return rgb
    except Exception as e:
        print(f"Error in process_image: {str(e)}")
        traceback.print_exc()
        return None
# Class labels, in the same order as the model's output logits.
CLASSES = ("Rope", "Hammer", "Other")


def _zero_scores():
    """Return a fresh dict mapping every class in CLASSES to 0.0."""
    return {cls: 0.0 for cls in CLASSES}


def predict(image):
    """Classify *image* and return per-class probabilities for `gr.Label`.

    Args:
        image: Raw value from the Gradio image component, or None.

    Returns:
        Dict mapping each class name in CLASSES to a float probability
        (softmax over the model's logits). When the input is missing, or
        preprocessing or inference fails, every class maps to 0.0 and the
        error is printed.
    """
    if image is None:
        return _zero_scores()
    try:
        processed_image = process_image(image)
        if processed_image is None:
            return _zero_scores()
        # Kept separate from the outer handler so the log distinguishes
        # preprocessing failures from inference failures.
        try:
            tensor_image = transform(processed_image).unsqueeze(0)  # add batch dim
            print(f"Input tensor shape: {tensor_image.shape}")
        except Exception as e:
            print(f"Error in tensor conversion: {str(e)}")
            traceback.print_exc()
            return _zero_scores()
        with torch.no_grad():
            outputs = model(tensor_image)
            print(f"Raw outputs: {outputs}")
            probabilities = F.softmax(outputs, dim=1)[0].cpu().numpy()
            print(f"Probabilities: {probabilities}")
        results = {cls: float(prob) for cls, prob in zip(CLASSES, probabilities)}
        print(f"Final results: {results}")
        return results
    except Exception as e:
        print(f"Prediction error: {str(e)}")
        traceback.print_exc()
        return _zero_scores()
# Gradio interface: a single image input wired to predict(), with a label
# output that displays the top three class scores.
interface = gr.Interface(
    title="Mechanical Tools Classifier",
    description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=3),
)
# Launch the interface only when this file is run as a script, not on import.
if __name__ == "__main__":
    interface.launch()