File size: 766 Bytes
dc7407d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from utils.datastruct import CNNData
import numpy as np
def conv3D_output_size(args, img_size):
if not isinstance(args, CNNData):
raise TypeError("input must be a ParserClass")
(cin, h , w) = img_size
# compute output shape of conv3D
for idx , chan in enumerate(args.kernel_size):
padding = args.paddings[idx]
stride = args.strides[idx]
(cin, h , w) = (np.floor((cin + 2 * padding[0] - chan[0] ) / stride[0] + 1).astype(int),
np.floor((h + 2 * padding[1] - chan[1] ) / stride[1] + 1).astype(int),
np.floor((w + 2 * padding[2] - chan[2] ) / stride[2] + 1).astype(int))
final_dim = int(args.n_f[-1] * cin * h * w)
return final_dim |