Spaces:
Configuration error

englert commited on
Commit
e059c3c
·
1 Parent(s): bbe0214

update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -35
app.py CHANGED
@@ -1,49 +1,56 @@
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
- import requests
4
- from torchvision import transforms
5
 
6
  from sampling_util import furthest_neighbours
7
  from video_reader import video_reader
8
 
9
  model = torch.load("model").eval()
10
- avg_pool = nn.AdaptiveAvgPool2d((1, 1))
 
11
 
 
 
12
 
13
- def predict(input_file):
14
- base_directory = os.getcwd()
15
  selected_directory = os.path.join(base_directory, "selected_images")
16
  if os.path.isdir(selected_directory):
17
  shutil.rmtree(selected_directory)
18
  os.mkdir(selected_directory)
19
-
20
  zip_path = os.path.join(input_file.split('/')[-1][:-4] + ".zip")
21
-
22
- mean = [0.3156024, 0.33569682, 0.34337464]
23
- std = [0.16568947, 0.17827448, 0.18925823]
24
-
25
  img_vecs = []
26
  with torch.no_grad():
27
- for fp_i, file_path in enumerate([input_file]):
28
- for i, in_img in enumerate(video_reader(file_path,
29
- targetFPS=9,
30
- targetWidth=100,
31
- to_rgb=True)):
32
- in_img = (in_img.astype(np.float32) / 255.)
33
- in_img = (in_img - mean) / std
34
- in_img = np.transpose(in_img, (0, 3, 1, 2))
35
- in_img = torch.from_numpy(in_img)
36
- encoded = avg_pool(model(in_img))[0, :, 0, 0].cpu().numpy()
37
- img_vecs += [encoded]
38
-
39
- img_vecs = np.asarray(img_vecs)
40
- rv_indices, _ = furthest_neighbours(
41
  img_vecs,
42
  downsample_size,
43
  seed=0)
44
  indices = np.zeros((img_vecs.shape[0],))
45
  indices[np.asarray(rv_indices)] = 1
46
-
47
  global_ctr = 0
48
  for fp_i, file_path in enumerate([input_file]):
49
  for i, img in enumerate(video_reader(file_path,
@@ -53,19 +60,21 @@ def predict(input_file):
53
  if indices[global_ctr] == 1:
54
  cv2.imwrite(join(selected_directory, str(global_ctr) + ".jpg"), img)
55
  global_ctr += 1
56
-
57
- all_selected_imgs_path = [join(selected_directory, f) for f in listdir(selected_directory) if isfile(join(selected_directory, f))]
58
- if 0 < len(all_file_paths):
 
59
  zipf = zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED)
60
  for i, f in enumerate(all_selected_imgs_path):
61
  zipf.write(f, basename(f))
62
  zipf.close()
63
-
64
- return zip_path
 
65
 
66
  demo = gr.Interface(
67
- fn=predict,
68
- inputs=gr.inputs.Video(label="Upload Video File"),
69
- outputs=gr.outputs.File(label="Zip"))
70
-
71
- demo.launch()
 
1
+ import os
2
+ import shutil
3
+ import zipfile
4
+ from os.path import join, isfile, basename
5
+
6
+ import cv2
7
+ import numpy as np
8
  import gradio as gr
9
  import torch
 
 
10
 
11
  from sampling_util import furthest_neighbours
12
  from video_reader import video_reader
13
 
14
  model = torch.load("model").eval()
15
+ avg_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
16
+
17
 
18
+ def predict(input_file, downsample_size):
19
+ downsample_size = int(downsample_size)
20
 
21
+ base_directory = os.getcwd()
 
22
  selected_directory = os.path.join(base_directory, "selected_images")
23
  if os.path.isdir(selected_directory):
24
  shutil.rmtree(selected_directory)
25
  os.mkdir(selected_directory)
26
+
27
  zip_path = os.path.join(input_file.split('/')[-1][:-4] + ".zip")
28
+
29
+ mean = np.asarray([0.3156024, 0.33569682, 0.34337464])
30
+ std = np.asarray([0.16568947, 0.17827448, 0.18925823])
31
+
32
  img_vecs = []
33
  with torch.no_grad():
34
+ for fp_i, file_path in enumerate([input_file]):
35
+ for i, in_img in enumerate(video_reader(file_path,
36
+ targetFPS=9,
37
+ targetWidth=100,
38
+ to_rgb=True)):
39
+ in_img = (in_img.astype(np.float32) / 255.)
40
+ in_img = (in_img - mean) / std
41
+ in_img = np.transpose(in_img, (0, 3, 1, 2))
42
+ in_img = torch.from_numpy(in_img)
43
+ encoded = avg_pool(model(in_img))[0, :, 0, 0].cpu().numpy()
44
+ img_vecs += [encoded]
45
+
46
+ img_vecs = np.asarray(img_vecs)
47
+ rv_indices, _ = furthest_neighbours(
48
  img_vecs,
49
  downsample_size,
50
  seed=0)
51
  indices = np.zeros((img_vecs.shape[0],))
52
  indices[np.asarray(rv_indices)] = 1
53
+
54
  global_ctr = 0
55
  for fp_i, file_path in enumerate([input_file]):
56
  for i, img in enumerate(video_reader(file_path,
 
60
  if indices[global_ctr] == 1:
61
  cv2.imwrite(join(selected_directory, str(global_ctr) + ".jpg"), img)
62
  global_ctr += 1
63
+
64
+ all_selected_imgs_path = [join(selected_directory, f) for f in os.listdir(selected_directory) if
65
+ isfile(join(selected_directory, f))]
66
+
67
  zipf = zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED)
68
  for i, f in enumerate(all_selected_imgs_path):
69
  zipf.write(f, basename(f))
70
  zipf.close()
71
+
72
+ return zip_path
73
+
74
 
75
  demo = gr.Interface(
76
+ fn=predict,
77
+ inputs=[gr.inputs.Video(label="Upload Video File"), gr.inputs.Number(Label="Downsample size")],
78
+ outputs=gr.outputs.File(label="Zip"))
79
+
80
+ demo.launch()