Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -9,6 +9,8 @@ from typing import NamedTuple, List, Callable, List, Tuple, Optional
|
|
9 |
from torch import nn
|
10 |
import torch.nn.functional as F
|
11 |
|
|
|
|
|
12 |
class LinData(NamedTuple):
|
13 |
in_dim : int # input dimension
|
14 |
hidden_layers : List[int] # hidden layers including the output layer
|
@@ -31,6 +33,25 @@ class CNNData(NamedTuple):
|
|
31 |
class NetData(NamedTuple):
|
32 |
cnn3d : CNNData
|
33 |
lin : LinData
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
class CNN3D_Mike(nn.Module):
|
36 |
def __init__(self, t_dim=30, img_x=256 , img_y=342, drop_p=0, fc_hidden1=256, fc_hidden2=256):
|
|
|
9 |
from torch import nn
|
10 |
import torch.nn.functional as F
|
11 |
|
12 |
+
|
13 |
+
|
14 |
class LinData(NamedTuple):
|
15 |
in_dim : int # input dimension
|
16 |
hidden_layers : List[int] # hidden layers including the output layer
|
|
|
33 |
class NetData(NamedTuple):
|
34 |
cnn3d : CNNData
|
35 |
lin : LinData
|
36 |
+
|
37 |
+
def conv3D_output_size(args, img_size):
|
38 |
+
|
39 |
+
if not isinstance(args, CNNData):
|
40 |
+
raise TypeError("input must be a ParserClass")
|
41 |
+
|
42 |
+
(cin, h , w) = img_size
|
43 |
+
# compute output shape of conv3D
|
44 |
+
for idx , chan in enumerate(args.kernel_size):
|
45 |
+
padding = args.paddings[idx]
|
46 |
+
stride = args.strides[idx]
|
47 |
+
(cin, h , w) = (np.floor((cin + 2 * padding[0] - chan[0] ) / stride[0] + 1).astype(int),
|
48 |
+
np.floor((h + 2 * padding[1] - chan[1] ) / stride[1] + 1).astype(int),
|
49 |
+
np.floor((w + 2 * padding[2] - chan[2] ) / stride[2] + 1).astype(int))
|
50 |
+
|
51 |
+
|
52 |
+
final_dim = int(args.n_f[-1] * cin * h * w)
|
53 |
+
|
54 |
+
return final_dim
|
55 |
|
56 |
class CNN3D_Mike(nn.Module):
|
57 |
def __init__(self, t_dim=30, img_x=256 , img_y=342, drop_p=0, fc_hidden1=256, fc_hidden2=256):
|