kimjy0411 commited on
Commit
11a4bd1
·
verified ·
1 Parent(s): bef9156

Upload src/utils/frame_interpolation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/utils/frame_interpolation.py +69 -0
src/utils/frame_interpolation.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/dajes/frame-interpolation-pytorch
2
+ import os
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ import bisect
7
+ import shutil
8
+ import pdb
9
+ from tqdm import tqdm
10
+
11
+ def init_frame_interpolation_model():
12
+ print("Initializing frame interpolation model")
13
+ checkpoint_name = os.path.join("./pretrained_model/film_net_fp16.pt")
14
+
15
+ model = torch.jit.load(checkpoint_name, map_location='cpu')
16
+ model.eval()
17
+ model = model.half()
18
+ model = model.to(device="cuda")
19
+ return model
20
+
21
+
22
+ def batch_images_interpolation_tool(input_tensor, model, inter_frames=1):
23
+
24
+ video_tensor = []
25
+ frame_num = input_tensor.shape[2] # bs, channel, frame, height, width
26
+
27
+ for idx in tqdm(range(frame_num-1)):
28
+ image1 = input_tensor[:,:,idx]
29
+ image2 = input_tensor[:,:,idx+1]
30
+
31
+ results = [image1, image2]
32
+
33
+ inter_frames = int(inter_frames)
34
+ idxes = [0, inter_frames + 1]
35
+ remains = list(range(1, inter_frames + 1))
36
+
37
+ splits = torch.linspace(0, 1, inter_frames + 2)
38
+
39
+ for _ in range(len(remains)):
40
+ starts = splits[idxes[:-1]]
41
+ ends = splits[idxes[1:]]
42
+ distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs()
43
+ matrix = torch.argmin(distances).item()
44
+ start_i, step = np.unravel_index(matrix, distances.shape)
45
+ end_i = start_i + 1
46
+
47
+ x0 = results[start_i]
48
+ x1 = results[end_i]
49
+
50
+ x0 = x0.half()
51
+ x1 = x1.half()
52
+ x0 = x0.cuda()
53
+ x1 = x1.cuda()
54
+
55
+ dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])
56
+
57
+ with torch.no_grad():
58
+ prediction = model(x0, x1, dt)
59
+ insert_position = bisect.bisect_left(idxes, remains[step])
60
+ idxes.insert(insert_position, remains[step])
61
+ results.insert(insert_position, prediction.clamp(0, 1).cpu().float())
62
+ del remains[step]
63
+
64
+ for sub_idx in range(len(results)-1):
65
+ video_tensor.append(results[sub_idx].unsqueeze(2))
66
+
67
+ video_tensor.append(input_tensor[:,:,-1].unsqueeze(2))
68
+ video_tensor = torch.cat(video_tensor, dim=2)
69
+ return video_tensor