File size: 4,694 Bytes
9f22b53
 
d06517d
9f22b53
9b05987
4348a2e
9b05987
4348a2e
970edfa
9f22b53
62fa3ff
970edfa
 
757e51c
7550275
970edfa
757e51c
7550275
970edfa
 
 
 
 
 
 
4348a2e
970edfa
 
 
 
 
 
757e51c
9c80e2d
6787f69
970edfa
9f22b53
 
4348a2e
9f22b53
 
 
 
9c80e2d
9f22b53
 
 
 
8342f3a
 
 
 
970edfa
9f22b53
 
8342f3a
 
 
 
 
9f22b53
 
9c80e2d
 
 
 
 
 
 
 
 
 
9f22b53
d16040b
9f22b53
 
d16040b
 
9c80e2d
9f22b53
 
970edfa
d06517d
970edfa
 
9c80e2d
d16040b
9c80e2d
7545949
9c80e2d
d16040b
9c80e2d
 
7545949
9c80e2d
 
 
 
 
 
 
 
9f22b53
970edfa
 
 
 
d06517d
d16040b
9c80e2d
970edfa
 
 
d16040b
7545949
d16040b
7545949
 
 
 
 
970edfa
d16040b
d06517d
970edfa
9c80e2d
 
970edfa
9c80e2d
970edfa
d16040b
970edfa
9c80e2d
 
 
970edfa
d06517d
970edfa
9c80e2d
970edfa
9f22b53
d16040b
9f22b53
 
7545949
d06517d
9f22b53
 
 
 
d16040b
9f22b53
970edfa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import subprocess
import sys
import os

def install(package):
    """Run pip for *package* using the interpreter executing this script.

    ``package`` is a pip requirement string, e.g. ``"numpy==1.24.3"``.  The
    ``--force-reinstall`` flag makes pip reinstall it even when that exact
    version is already present.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.

    NOTE(review): ``--force-reinstall`` also reinstalls the package's
    dependencies, which may replace versions pinned earlier (e.g. NumPy) —
    confirm this is intended.
    """
    pip_command = [sys.executable, "-m", "pip", "install", "--force-reinstall"]
    pip_command.append(package)
    subprocess.check_call(pip_command)

# First, ensure NumPy is installed with the correct version
try:
    import numpy as np
    if not np.__version__.startswith("1.24"):
        print("Installing compatible NumPy version...")
        install("numpy==1.24.3")
        # The previously imported NumPy stays loaded in this process; the
        # newly installed version is only used after the app is restarted.
        print("NumPy was reinstalled; restart the app to load the new version.")
except ImportError:
    print("NumPy not found. Installing...")
    install("numpy==1.24.3")

# Then install other dependencies (pip distribution name -> pinned version)
packages = {
    "torch": "2.0.1",
    "torchvision": "0.15.2",
    "Pillow": "9.5.0",
    "gradio": "3.50.2"
}

# Import names that differ from the lower-cased distribution name.  Pillow is
# imported as ``PIL``; probing for a module called ``pillow`` always fails,
# which previously forced a reinstall of Pillow on every start-up.
_IMPORT_NAMES = {"Pillow": "PIL"}

for package, version in packages.items():
    try:
        __import__(_IMPORT_NAMES.get(package, package.lower()))
    except ImportError:
        print(f"Installing {package}...")
        install(f"{package}=={version}")


import traceback
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr

# Define the model exactly as in training
class ModifiedLargeNet(nn.Module):
    """Two-stage conv net producing 3 logits per RGB image.

    Layer attribute names and shapes must match the training-time definition
    so the saved ``state_dict`` loads.  For a 128x128 input the feature map
    after the second pooling stage is 10x29x29:
    128 -conv5-> 124 -pool-> 62 -conv5-> 58 -pool-> 29.
    """

    def __init__(self):
        super().__init__()
        self.name = "modified_large"
        self.conv1 = nn.Conv2d(3, 5, 5)  # 3 -> 5 channels, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)  # halves H and W; reused after both convs
        self.conv2 = nn.Conv2d(5, 10, 5)  # 5 -> 10 channels, 5x5 kernel
        self.fc1 = nn.Linear(10 * 29 * 29, 32)
        self.fc2 = nn.Linear(32, 3)

    def forward(self, x):
        """Return logits of shape (N, 3) for a batch of shape (N, 3, 128, 128).

        Raises:
            ValueError: if the input's spatial size does not yield 10x29x29
                features per sample.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Check the per-sample feature shape before flattening: view(-1, ...)
        # alone would silently regroup features across the batch (and return
        # the wrong number of rows) when the input size is wrong.
        if x.shape[-3:] != (10, 29, 29):
            raise ValueError(
                "expected 10x29x29 features after the conv stages "
                f"(a 128x128 input), got {tuple(x.shape[-3:])}"
            )
        x = x.view(-1, 10 * 29 * 29)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Load the trained model with error handling.
# On failure ``model`` is set to None rather than left as a randomly
# initialised network: random weights would produce plausible-looking but
# meaningless predictions, whereas calling None fails and is reported as an
# error by the caller.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted
# source.
try:
    model = ModifiedLargeNet()
    state_dict = torch.load("modified_large_net.pt", map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()  # inference mode
    print("Model loaded successfully")
except Exception as e:
    model = None
    print(f"Error loading model: {str(e)}")
    traceback.print_exc()

# Preprocessing for each RGB PIL image before inference:
#   1. resize to 128x128 (the size the model's 10*29*29 fc1 input assumes),
#   2. convert to a uint8 tensor and rescale to float32 in [0, 1],
#   3. normalise per channel with the common ImageNet mean/std.
# NOTE(review): confirm these normalisation statistics match training.
_preprocessing_steps = [
    transforms.Resize((128, 128)),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float32),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocessing_steps)

def process_image(image):
    """Turn an uploaded image into an RGB PIL image.

    Args:
        image: A numpy array (cast with ``astype('uint8')`` before conversion),
            an object already exposing a PIL-style ``mode``/``convert``
            interface, or ``None``.

    Returns:
        The RGB image, or ``None`` if *image* is ``None`` or any step raises
        (the error and its traceback are printed).
    """
    if image is None:
        return None

    try:
        pil_image = (
            Image.fromarray(image.astype('uint8'))
            if isinstance(image, np.ndarray)
            else image
        )
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')

        print(f"Processed image size: {pil_image.size}")
        print(f"Processed image mode: {pil_image.mode}")

        return pil_image
    except Exception as e:
        print(f"Error in process_image: {str(e)}")
        traceback.print_exc()
        return None

def predict(image):
    """Classify *image* as one of the tool classes.

    Args:
        image: The value supplied by the Gradio ``Image`` input, or ``None``
            when nothing was uploaded.

    Returns:
        A dict mapping each label (``Rope``, ``Hammer``, ``Other``, in that
        order) to its softmax probability.  When there is no input or any step
        fails, every label maps to ``0.0`` (errors are printed with a
        traceback) so the ``gr.Label`` output still renders.
    """
    # Single source of truth for the labels; their order must match the order
    # of the model's output logits.
    classes = ["Rope", "Hammer", "Other"]

    def no_prediction():
        # Fresh all-zero result used for missing input and failures.
        return {cls: 0.0 for cls in classes}

    if image is None:
        return no_prediction()

    try:
        # Process the image
        processed_image = process_image(image)
        if processed_image is None:
            return no_prediction()

        # Transform image to tensor using torchvision transforms
        try:
            tensor_image = transform(processed_image).unsqueeze(0)
            print(f"Input tensor shape: {tensor_image.shape}")
        except Exception as e:
            print(f"Error in tensor conversion: {str(e)}")
            traceback.print_exc()
            return no_prediction()

        # Make prediction
        with torch.no_grad():
            outputs = model(tensor_image)
            print(f"Raw outputs: {outputs}")

            probabilities = F.softmax(outputs, dim=1)[0].cpu().numpy()
            print(f"Probabilities: {probabilities}")

        # zip() would silently truncate if the model's output size ever
        # diverged from the label list, so check explicitly.
        if len(probabilities) != len(classes):
            raise ValueError(
                f"Model returned {len(probabilities)} scores for "
                f"{len(classes)} classes"
            )

        # Return results
        results = {cls: float(prob) for cls, prob in zip(classes, probabilities)}
        print(f"Final results: {results}")
        return results

    except Exception as e:
        print(f"Prediction error: {str(e)}")
        traceback.print_exc()
        return no_prediction()

# Gradio UI: one image input, the top-3 class probabilities as output.
# Components are created in the same order as before (input, then output).
_image_input = gr.Image()
_label_output = gr.Label(num_top_classes=3)

interface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    title="Mechanical Tools Classifier",
    description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
)

if __name__ == "__main__":
    # Start the web app only when this file is run as a script, not on import.
    interface.launch()