StoneSeller commited on
Commit
831f00d
·
verified ·
1 Parent(s): 1394155

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +54 -0
  2. modified_large_net.pt +3 -0
  3. requriements.txt +4 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ import gradio as gr
6
+
7
+
8
+ class ModifiedLargeNet(nn.Module):
9
+ def __init__(self):
10
+ super(ModifiedLargeNet, self).__init__()
11
+ self.name = "modified_large"
12
+ self.fc1 = nn.Linear(128 * 128 * 3, 256)
13
+ self.fc2 = nn.Linear(256, 128)
14
+ self.fc3 = nn.Linear(128, 3) # 3 classes: Rope, Hammer, Other
15
+
16
+ def forward(self, x):
17
+ x = x.view(-1, 128 * 128 * 3)
18
+ x = torch.relu(self.fc1(x))
19
+ x = torch.relu(self.fc2(x))
20
+ x = self.fc3(x)
21
+ return x
22
+
23
+
24
+ model = ModifiedLargeNet()
25
+ model.load_state_dict(torch.load("modified_large_net.pt", map_location=torch.device("cpu")))
26
+ model.eval()
27
+
28
+
29
+ transform = transforms.Compose([
30
+ transforms.Resize((128, 128)),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
33
+ ])
34
+
35
+
36
+ def predict(image):
37
+
38
+ image = transform(image).unsqueeze(0)
39
+ with torch.no_grad():
40
+ outputs = model(image)
41
+ probabilities = torch.softmax(outputs, dim=1).numpy()[0]
42
+ classes = ["Rope", "Hammer", "Other"]
43
+ return {cls: float(prob) for cls, prob in zip(classes, probabilities)}
44
+
45
+ interface = gr.Interface(
46
+ fn=predict,
47
+ inputs=gr.Image(type="pil"),
48
+ outputs=gr.Label(num_top_classes=3),
49
+ title="Mechanical Tools Classifier",
50
+ description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
51
+ )
52
+
53
+ if __name__ == "__main__":
54
+ interface.launch()
modified_large_net.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:10ca86bf582a3820406346d57fa09083279fbb4aa1862d44b966144a556d228c
3
+ size 1086788
requriements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ Pillow