biplab2008 commited on
Commit
29e2a42
·
verified ·
1 Parent(s): 91ed8d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -0
app.py CHANGED
@@ -6,6 +6,41 @@ import torch
6
  from PIL import Image
7
  import torchvision.transforms as transforms
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def parse_video(video_file):
10
  """A utility to parse the input videos.
11
  Reference: https://pyimagesearch.com/2018/11/12/yolo-object-detection-with-opencv/
@@ -46,6 +81,8 @@ def parse_video(video_file):
46
  return frames
47
 
48
  def pil_parser(video_file):
 
 
49
  # cv2 parsing
50
 
51
  dummy_frames = parse_video(video_file)
 
6
  from PIL import Image
7
  import torchvision.transforms as transforms
8
 
9
+
10
+ def load_model():
11
+ # CNN3D Layer's architecture
12
+ cnndata = CNNData(in_dim = 1,
13
+ n_f =[32,48],
14
+ kernel_size=[(5,5,5), (3,3,3)],
15
+ activations=[nn.ReLU(),nn.ReLU()],
16
+ bns = [True, True],
17
+ dropouts = [0, 0],
18
+ paddings = [(0,0,0),(0,0,0)],
19
+ strides = [(2,2,2),(2,2,2)])
20
+
21
+ # Feedforward layer's architecture
22
+ lindata = LinData(in_dim = conv3D_output_size(cnndata, [30, 256, 342]),
23
+ hidden_layers= [256,256,1],
24
+ activations=[nn.ReLU(),nn.ReLU(),None],
25
+ bns=[False,False,False],
26
+ dropouts =[0.2, 0, 0])
27
+
28
+ # combined architecture
29
+ args = NetData(cnndata, lindata)
30
+
31
+ # weight file
32
+ #weight_file = 'cnn3d_epoch_300.pt'
33
+
34
+ # CNN3D model
35
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
36
+ device = torch.device('cpu')
37
+ cnn3d = CNN3D(args).to(device)
38
+ #cnn3d.load_state_dict(torch.load(os.path.join(base_path,'weights',weight_file), map_location=device))
39
+ cnn3d.eval()
40
+ #print(cnn3d)
41
+
42
+ return cnn3d
43
+
44
  def parse_video(video_file):
45
  """A utility to parse the input videos.
46
  Reference: https://pyimagesearch.com/2018/11/12/yolo-object-detection-with-opencv/
 
81
  return frames
82
 
83
  def pil_parser(video_file):
84
+
85
+ model = load_model()
86
  # cv2 parsing
87
 
88
  dummy_frames = parse_video(video_file)