StoneSeller commited on
Commit
9f22b53
·
verified ·
1 Parent(s): 397c575

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -54
app.py CHANGED
@@ -1,54 +1,67 @@
1
- import torch
2
- import torch.nn as nn
3
- import torchvision.transforms as transforms
4
- from PIL import Image
5
- import gradio as gr
6
-
7
-
8
- class ModifiedLargeNet(nn.Module):
9
- def __init__(self):
10
- super(ModifiedLargeNet, self).__init__()
11
- self.name = "modified_large"
12
- self.fc1 = nn.Linear(128 * 128 * 3, 256)
13
- self.fc2 = nn.Linear(256, 128)
14
- self.fc3 = nn.Linear(128, 3) # 3 classes: Rope, Hammer, Other
15
-
16
- def forward(self, x):
17
- x = x.view(-1, 128 * 128 * 3)
18
- x = torch.relu(self.fc1(x))
19
- x = torch.relu(self.fc2(x))
20
- x = self.fc3(x)
21
- return x
22
-
23
-
24
- model = ModifiedLargeNet()
25
- model.load_state_dict(torch.load("modified_large_net.pt", map_location=torch.device("cpu")))
26
- model.eval()
27
-
28
-
29
- transform = transforms.Compose([
30
- transforms.Resize((128, 128)),
31
- transforms.ToTensor(),
32
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
33
- ])
34
-
35
-
36
- def predict(image):
37
-
38
- image = transform(image).unsqueeze(0)
39
- with torch.no_grad():
40
- outputs = model(image)
41
- probabilities = torch.softmax(outputs, dim=1).numpy()[0]
42
- classes = ["Rope", "Hammer", "Other"]
43
- return {cls: float(prob) for cls, prob in zip(classes, probabilities)}
44
-
45
- interface = gr.Interface(
46
- fn=predict,
47
- inputs=gr.Image(type="pil"),
48
- outputs=gr.Label(num_top_classes=3),
49
- title="Mechanical Tools Classifier",
50
- description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
51
- )
52
-
53
- if __name__ == "__main__":
54
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import sys
4
+
5
+ try:
6
+ import torch
7
+ except ImportError:
8
+ subprocess.check_call([sys.executable, "-m", "pip", "install",
9
+ "torch==2.0.1+cpu",
10
+ "torchvision==0.15.2+cpu",
11
+ "-f", "https://download.pytorch.org/whl/torch_stable.html"])
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torchvision.transforms as transforms
16
+ from PIL import Image
17
+ import gradio as gr
18
+
19
+
20
+
21
+ class ModifiedLargeNet(nn.Module):
22
+ def __init__(self):
23
+ super(ModifiedLargeNet, self).__init__()
24
+ self.name = "modified_large"
25
+ self.fc1 = nn.Linear(128 * 128 * 3, 256)
26
+ self.fc2 = nn.Linear(256, 128)
27
+ self.fc3 = nn.Linear(128, 3) # 3 classes: Rope, Hammer, Other
28
+
29
+ def forward(self, x):
30
+ x = x.view(-1, 128 * 128 * 3)
31
+ x = torch.relu(self.fc1(x))
32
+ x = torch.relu(self.fc2(x))
33
+ x = self.fc3(x)
34
+ return x
35
+
36
+
37
+ model = ModifiedLargeNet()
38
+ model.load_state_dict(torch.load("modified_large_net.pt", map_location=torch.device("cpu")))
39
+ model.eval()
40
+
41
+
42
+ transform = transforms.Compose([
43
+ transforms.Resize((128, 128)),
44
+ transforms.ToTensor(),
45
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
46
+ ])
47
+
48
+
49
+ def predict(image):
50
+
51
+ image = transform(image).unsqueeze(0)
52
+ with torch.no_grad():
53
+ outputs = model(image)
54
+ probabilities = torch.softmax(outputs, dim=1).numpy()[0]
55
+ classes = ["Rope", "Hammer", "Other"]
56
+ return {cls: float(prob) for cls, prob in zip(classes, probabilities)}
57
+
58
+ interface = gr.Interface(
59
+ fn=predict,
60
+ inputs=gr.Image(type="pil"),
61
+ outputs=gr.Label(num_top_classes=3),
62
+ title="Mechanical Tools Classifier",
63
+ description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
64
+ )
65
+
66
+ if __name__ == "__main__":
67
+ interface.launch()