Spaces:
Configuration error

englert committed on
Commit
6838da7
·
1 Parent(s): 7c9fcd5

rollback to original

Browse files
Files changed (1) hide show
  1. app.py +79 -78
app.py CHANGED
@@ -1,85 +1,86 @@
1
- # import os
2
- # import shutil
3
- # import zipfile
4
- # from os.path import join, isfile, basename
5
- #
6
- # import cv2
7
- # import numpy as np
8
  import gradio as gr
9
- # import torch
 
 
 
 
 
10
 
11
- # from resnet50 import resnet18
12
- # from sampling_util import furthest_neighbours
13
- # from video_reader import video_reader
14
- #
15
- # model = resnet18(
16
- # output_dim=0,
17
- # nmb_prototypes=0,
18
- # eval_mode=True,
19
- # hidden_mlp=0,
20
- # normalize=False)
21
- # model.load_state_dict(torch.load("model.pt"))
22
- # model.eval()
23
- # avg_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
24
 
25
 
26
  def predict(input_file, downsample_size):
27
- # downsample_size = int(downsample_size)
28
- # base_directory = os.getcwd()
29
- # selected_directory = os.path.join(base_directory, "selected_images")
30
- # if os.path.isdir(selected_directory):
31
- # shutil.rmtree(selected_directory)
32
- # os.mkdir(selected_directory)
33
- #
34
- # file_name = (input_file.split('/')[-1]).split('.')[-1]
35
- # zip_path = os.path.join(selected_directory, file_name + ".zip")
36
- #
37
- # mean = np.asarray([0.3156024, 0.33569682, 0.34337464], dtype=np.float32)
38
- # std = np.asarray([0.16568947, 0.17827448, 0.18925823], dtype=np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- # img_vecs = []
41
- # with torch.no_grad():
42
- # for fp_i, file_path in enumerate([input_file]):
43
- # for i, in_img in enumerate(video_reader(file_path,
44
- # targetFPS=9,
45
- # targetWidth=100,
46
- # to_rgb=True)):
47
- # in_img = (in_img.astype(np.float32) / 255.)
48
- # in_img = (in_img - mean) / std
49
- # in_img = np.expand_dims(in_img, 0)
50
- # in_img = np.transpose(in_img, (0, 3, 1, 2))
51
- # in_img = torch.from_numpy(in_img).float()
52
- # encoded = avg_pool(model(in_img))[0, :, 0, 0].cpu().numpy()
53
- # img_vecs += [encoded]
54
- # img_vecs = np.asarray(img_vecs)
55
- # print("images encoded")
56
- # rv_indices, _ = furthest_neighbours(
57
- # x=img_vecs,
58
- # downsample_size=downsample_size,
59
- # seed=0)
60
- # indices = np.zeros((img_vecs.shape[0],))
61
- # indices[np.asarray(rv_indices)] = 1
62
- # print("images selected")
63
 
64
- # global_ctr = 0
65
- # for fp_i, file_path in enumerate([input_file]):
66
- # for i, img in enumerate(video_reader(file_path,
67
- # targetFPS=9,
68
- # targetWidth=None,
69
- # to_rgb=False)):
70
- # if indices[global_ctr] == 1:
71
- # cv2.imwrite(join(selected_directory, str(global_ctr) + ".jpg"), img)
72
- # global_ctr += 1
73
- # print("selected images extracted")
74
- #
75
- # all_selected_imgs_path = [join(selected_directory, f) for f in os.listdir(selected_directory) if
76
- # isfile(join(selected_directory, f))]
77
 
78
- # zipf = zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED)
79
- # # for i, f in enumerate(all_selected_imgs_path):
80
- # # zipf.write(f, basename(f))
81
- # zipf.close()
82
- # print("selected images zipped")
83
 
84
  return input_file
85
 
@@ -89,9 +90,9 @@ demo = gr.Interface(
89
  title="Frame selection by visual difference",
90
  description="",
91
  fn=predict,
92
- inputs=[gr.inputs.Video(label="Upload Video File"),
93
- gr.inputs.Number(label="Downsample size")],
94
- outputs=gr.outputs.File(label="Zip"),
95
  )
96
 
97
- demo.launch()
 
1
+ import os
2
+ import shutil
3
+ import zipfile
4
+ from os.path import join, isfile, basename
5
+
6
+ import cv2
7
+ import numpy as np
8
  import gradio as gr
9
+ from gradio.components import Video, Number, File
10
+ import torch
11
+
12
+ from resnet50 import resnet18
13
+ from sampling_util import furthest_neighbours
14
+ from video_reader import video_reader
15
 
16
# Frame-encoder backbone: a headless ResNet-18 (no projection head, no
# prototypes) loaded from the checkpoint shipped alongside the app, put
# into eval mode, plus a global average pool that reduces each frame's
# feature map to a single embedding vector.
model = resnet18(output_dim=0, nmb_prototypes=0, eval_mode=True,
                 hidden_mlp=0, normalize=False)
model.load_state_dict(torch.load("model.pt"))
model.eval()
avg_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
 
 
 
 
25
 
26
 
27
def predict(input_file, downsample_size):
    """Select the most visually diverse frames of a video and zip them.

    Two passes over the video: pass 1 encodes every frame (downscaled to
    width 100) with the ResNet backbone, then furthest-neighbour sampling
    picks ``downsample_size`` mutually distant embeddings; pass 2 re-reads
    the video at full resolution and writes only the selected frames as
    JPEGs, which are then zipped.

    Parameters
    ----------
    input_file : str
        Path to the uploaded video file.
    downsample_size : float or int
        Number of frames to keep (a Gradio ``Number`` arrives as float).

    Returns
    -------
    str
        The input file path.
        NOTE(review): the interface output is a File labelled "Zip" —
        ``zip_path`` may be the intended return value; confirm before
        changing, as this mirrors the original behavior.
    """
    downsample_size = int(downsample_size)

    # Start from a fresh output directory on every invocation.
    base_directory = os.getcwd()
    selected_directory = os.path.join(base_directory, "selected_images")
    if os.path.isdir(selected_directory):
        shutil.rmtree(selected_directory)
    os.mkdir(selected_directory)

    # BUG FIX: the original used split('.')[-1], which yields the file
    # EXTENSION (e.g. "mp4"), not the name — the zip was named "mp4.zip".
    # Use the stem of the basename instead.
    file_name = os.path.splitext(os.path.basename(input_file))[0]
    zip_path = os.path.join(selected_directory, file_name + ".zip")

    # Per-channel normalisation statistics matching the encoder's training.
    mean = np.asarray([0.3156024, 0.33569682, 0.34337464], dtype=np.float32)
    std = np.asarray([0.16568947, 0.17827448, 0.18925823], dtype=np.float32)

    # Pass 1: encode every frame into a feature vector.
    img_vecs = []
    with torch.no_grad():
        for fp_i, file_path in enumerate([input_file]):
            for i, in_img in enumerate(video_reader(file_path,
                                                    targetFPS=9,
                                                    targetWidth=100,
                                                    to_rgb=True)):
                in_img = in_img.astype(np.float32) / 255.
                in_img = (in_img - mean) / std
                in_img = np.expand_dims(in_img, 0)           # add batch dim
                in_img = np.transpose(in_img, (0, 3, 1, 2))  # NHWC -> NCHW
                in_img = torch.from_numpy(in_img).float()
                encoded = avg_pool(model(in_img))[0, :, 0, 0].cpu().numpy()
                img_vecs.append(encoded)
    img_vecs = np.asarray(img_vecs)
    print("images encoded")

    # Choose the downsample_size most mutually distant embeddings;
    # indices is a 0/1 mask over frame positions.
    rv_indices, _ = furthest_neighbours(
        x=img_vecs,
        downsample_size=downsample_size,
        seed=0)
    indices = np.zeros((img_vecs.shape[0],))
    indices[np.asarray(rv_indices)] = 1
    print("images selected")

    # Pass 2: re-read at full resolution (same FPS, so the frame count is
    # assumed to match pass 1 — TODO confirm video_reader guarantees this)
    # and save only the selected frames.
    global_ctr = 0
    for fp_i, file_path in enumerate([input_file]):
        for i, img in enumerate(video_reader(file_path,
                                             targetFPS=9,
                                             targetWidth=None,
                                             to_rgb=False)):
            if indices[global_ctr] == 1:
                cv2.imwrite(join(selected_directory, str(global_ctr) + ".jpg"),
                            img)
            global_ctr += 1
    print("selected images extracted")

    all_selected_imgs_path = [join(selected_directory, f)
                              for f in os.listdir(selected_directory)
                              if isfile(join(selected_directory, f))]

    # Context manager guarantees the archive is closed (and flushed) even
    # if a write fails — the original leaked the handle on error.
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for f in all_selected_imgs_path:
            zipf.write(f, basename(f))
    print("selected images zipped")

    return input_file
86
 
 
90
  title="Frame selection by visual difference",
91
  description="",
92
  fn=predict,
93
+ inputs=[Video(label="Upload Video File"),
94
+ Number(label="Downsample size")],
95
+ outputs=File(label="Zip"),
96
  )
97
 
98
+ demo.launch(enable_queue=True)