File size: 4,738 Bytes
9f22b53
 
d06517d
9f22b53
7077e08
b8cf886
 
 
 
 
 
 
 
 
 
 
7077e08
b8cf886
9c80e2d
6787f69
970edfa
9f22b53
 
4348a2e
9f22b53
 
 
 
7077e08
9f22b53
 
 
 
8342f3a
 
 
 
970edfa
9f22b53
 
8342f3a
 
 
 
 
9f22b53
 
7077e08
9c80e2d
 
 
 
 
 
 
 
 
9f22b53
 
 
7077e08
9c80e2d
9f22b53
 
b8cf886
 
 
 
 
 
 
 
 
 
 
 
 
 
 
970edfa
d06517d
970edfa
 
9c80e2d
 
b8cf886
9c80e2d
 
 
7545949
6cc530c
 
9c80e2d
 
6cc530c
9c80e2d
 
 
 
 
 
9f22b53
970edfa
 
 
 
d06517d
9c80e2d
970edfa
 
 
7545949
b8cf886
7077e08
b8cf886
7545949
6cc530c
 
b8cf886
7545949
 
 
 
970edfa
b8cf886
 
 
 
 
 
 
9c80e2d
b8cf886
 
 
 
 
 
 
 
 
970edfa
d06517d
970edfa
9c80e2d
970edfa
9f22b53
d16040b
9f22b53
 
7077e08
d06517d
9f22b53
 
 
 
d16040b
9f22b53
970edfa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import subprocess
import sys
import os


def install_requirements():
    """Install the pinned runtime dependencies into the running interpreter.

    All pins are handed to a single ``pip install`` invocation so that pip
    resolves them together. With ``--force-reinstall`` pip also reinstalls
    the dependencies of each requested package, so installing the pins one
    at a time lets a later install (for example ``torchvision`` or
    ``gradio``, which depend on numpy/Pillow) replace a version that an
    earlier call had just pinned, and re-downloads shared packages such as
    ``torch`` several times.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    packages = [
        "numpy==1.24.3",
        "torch==2.0.1",
        "torchvision==0.15.2",
        "Pillow==9.5.0",
        "gradio==3.50.2",
    ]
    # sys.executable ensures the packages land in the environment of the
    # interpreter that is running this script, not whichever pip is on PATH.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--force-reinstall", *packages]
    )


# Executed at import time, ahead of the third-party imports below, so the
# pinned packages are installed before this module tries to import them.
install_requirements()

import traceback
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr


class ModifiedLargeNet(nn.Module):
    """Small CNN: two conv + max-pool stages followed by two linear layers.

    The first linear layer expects 10 * 29 * 29 features, which is what the
    conv/pool stack produces for a 3 x 128 x 128 input
    (128 -> 124 -> 62 -> 58 -> 29 per spatial side). The network returns
    three raw scores (logits) per sample.
    """

    def __init__(self):
        super().__init__()
        self.name = "modified_large"
        # Attribute names and creation order are kept as-is: they determine
        # the state_dict keys that a saved checkpoint must match.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=5, out_channels=10, kernel_size=5)
        self.fc1 = nn.Linear(in_features=10 * 29 * 29, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=3)

    def forward(self, x):
        """Run the forward pass and return the unnormalised class scores."""
        first_stage = self.pool(F.relu(self.conv1(x)))
        second_stage = self.pool(F.relu(self.conv2(first_stage)))
        # Flatten to (batch, 10 * 29 * 29) for the fully connected layers.
        flattened = second_stage.view(-1, self.fc1.in_features)
        hidden = F.relu(self.fc1(flattened))
        return self.fc2(hidden)


# Build the network and load its trained weights from "modified_large_net.pt"
# onto the CPU, then switch it to evaluation mode.
#
# Loading is best-effort so the app can still start and show the error in its
# logs. On any failure `model` is explicitly set to None: otherwise it would
# be left undefined, or worse, left as a randomly initialised network that
# produces meaningless probabilities. Calling None in predict() raises, and
# predict() already turns that into an all-zero result.
try:
    model = ModifiedLargeNet()
    state_dict = torch.load("modified_large_net.pt", map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    print("Model loaded successfully")
    model.eval()
except Exception as e:
    print(f"Error loading model: {str(e)}")
    traceback.print_exc()
    model = None

# Torchvision preprocessing pipeline: resize to 128x128, convert to a tensor,
# then normalise each channel with the mean/std values below.
# NOTE(review): nothing in the visible code references `transform`; predict()
# goes through process_image() and custom_transform() instead. Confirm
# whether this pipeline is still needed.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(), 
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def custom_transform(pil_image):
    """Convert an RGB image into a normalised ``(3, H, W)`` float32 tensor.

    Pixel values are divided by 255 and then standardised per channel as
    ``(value - mean) / std`` using the mean/std constants below.

    Args:
        pil_image: A PIL image, or an ``(H, W, 3)`` array-like of pixel
            values (expected to be in the 0-255 range). A PIL image in any
            other mode (grayscale, RGBA, palette, ...) is converted to RGB
            first.

    Returns:
        A ``torch.float32`` tensor of shape ``(3, H, W)``.

    Raises:
        ValueError: If the pixel data is not of shape ``(H, W, 3)``.
    """
    # Non-RGB PIL images would otherwise yield an array without exactly
    # three channels and fail later with an opaque shape error.
    if getattr(pil_image, "mode", "RGB") != "RGB":
        pil_image = pil_image.convert("RGB")

    np_image = np.array(pil_image)
    if np_image.ndim != 3 or np_image.shape[2] != 3:
        raise ValueError(
            f"Expected an (H, W, 3) image, got array of shape {np_image.shape}"
        )

    # HWC -> CHW, then scale to [0, 1].
    tensor_image = torch.from_numpy(np_image.transpose((2, 0, 1))).float()
    tensor_image = tensor_image / 255.0

    # Per-channel standardisation with plain tensor arithmetic; this avoids
    # constructing a Normalize transform object on every call.
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=tensor_image.dtype).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], dtype=tensor_image.dtype).view(3, 1, 1)
    return (tensor_image - mean) / std

def process_image(image):
    """Return ``image`` as a 128x128 RGB PIL image, or None on failure.

    A numpy array is first cast to uint8 and wrapped in a PIL image; any
    non-RGB image is converted to RGB; the result is resized with LANCZOS
    resampling. ``None`` input yields ``None``. Any exception raised along
    the way is printed with its traceback and also yields ``None``.
    """
    if image is None:
        return None

    try:
        # Accept raw arrays as well as PIL images.
        pil_img = (
            Image.fromarray(image.astype('uint8'))
            if isinstance(image, np.ndarray)
            else image
        )

        rgb_img = pil_img if pil_img.mode == 'RGB' else pil_img.convert('RGB')
        resized = rgb_img.resize((128, 128), Image.Resampling.LANCZOS)

        # Debug output describing the processed image.
        print(f"Processed image size: {resized.size}")
        print(f"Processed image mode: {resized.mode}")
        print(f"Image type: {type(resized)}")
    except Exception as e:
        print(f"Error in process_image: {str(e)}")
        traceback.print_exc()
        return None
    else:
        return resized

def predict(image):
    """Classify ``image`` and return a score for each class label.

    Returns a dict mapping "Rope", "Hammer" and "Other" to the softmax
    probabilities produced by the model. If ``image`` is None, if
    preprocessing returns None, or if any step raises, every label maps to
    0.0 instead; exceptions are printed with their traceback rather than
    propagated.
    """
    class_names = ["Rope", "Hammer", "Other"]

    def zero_scores():
        # Fallback result shared by the missing-input and failure paths.
        return dict.fromkeys(class_names, 0.0)

    if image is None:
        return zero_scores()

    try:
        prepared = process_image(image)
        if prepared is None:
            return zero_scores()

        # Step 1: image -> normalised tensor with a leading batch dimension.
        try:
            batch = custom_transform(prepared).unsqueeze(0)

            print(f"Input tensor shape: {batch.shape}")
            print(f"Tensor dtype: {batch.dtype}")
            print(f"Tensor device: {batch.device}")
        except Exception as e:
            print(f"Error in tensor conversion: {str(e)}")
            traceback.print_exc()
            return zero_scores()

        # Step 2: forward pass without gradient tracking, then softmax.
        try:
            with torch.no_grad():
                logits = model(batch)
                print(f"Raw outputs: {logits}")

                probs = F.softmax(logits, dim=1)[0].cpu().numpy()
                print(f"Probabilities: {probs}")

            scores = {}
            for label, prob in zip(class_names, probs):
                scores[label] = float(prob)
            print(f"Final results: {scores}")
            return scores
        except Exception as e:
            print(f"Error in prediction: {str(e)}")
            traceback.print_exc()
            return zero_scores()

    except Exception as e:
        print(f"Prediction error: {str(e)}")
        traceback.print_exc()
        return zero_scores()

# Gradio interface: wires predict() to a single image input and a label
# output configured to show up to three classes.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),  # type="pil" requests the upload as a PIL image
    outputs=gr.Label(num_top_classes=3),
    title="Mechanical Tools Classifier",
    description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
)

# Launch the interface only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    interface.launch()