StoneSeller commited on
Commit
8342f3a
·
verified ·
1 Parent(s): 9f22b53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -22,15 +22,19 @@ class ModifiedLargeNet(nn.Module):
22
  def __init__(self):
23
  super(ModifiedLargeNet, self).__init__()
24
  self.name = "modified_large"
25
- self.fc1 = nn.Linear(128 * 128 * 3, 256)
26
- self.fc2 = nn.Linear(256, 128)
27
- self.fc3 = nn.Linear(128, 3) # 3 classes: Rope, Hammer, Other
 
 
28
 
29
  def forward(self, x):
30
- x = x.view(-1, 128 * 128 * 3)
31
- x = torch.relu(self.fc1(x))
32
- x = torch.relu(self.fc2(x))
33
- x = self.fc3(x)
 
 
34
  return x
35
 
36
 
 
22
  def __init__(self):
23
  super(ModifiedLargeNet, self).__init__()
24
  self.name = "modified_large"
25
+ self.conv1 = nn.Conv2d(3, 5, 5)
26
+ self.pool = nn.MaxPool2d(2, 2)
27
+ self.conv2 = nn.Conv2d(5, 10, 5)
28
+ self.fc1 = nn.Linear(10 * 29 * 29, 32)
29
+ self.fc2 = nn.Linear(32, 3) # classify into "Rope"/"Hammer"/"others"
30
 
31
  def forward(self, x):
32
+ x = self.pool(F.relu(self.conv1(x)))
33
+ x = self.pool(F.relu(self.conv2(x)))
34
+ x = x.view(-1, 10 * 29 * 29)
35
+ x = F.relu(self.fc1(x))
36
+ x = self.fc2(x)
37
+ x = x.squeeze(1) # Flatten to [batch_size]
38
  return x
39
 
40