# Uploaded by SuyeonJ via huggingface_hub ("Upload folder using huggingface_hub"), commit 8d015d4.
import os
import sys
import torch
import argparse
import numpy as np
import os.path as osp
import re
from imageio import imread, imwrite
import torch.nn.functional as F
sys.path.append('.')
from utils.flow_generation.liteflownet.run import estimate
def read(file):
    """Load ``file`` with the reader that matches its extension.

    ``.float3`` -> readFloat, ``.flo`` -> readFlow,
    ``.ppm``/``.pgm``/``.png``/``.jpg`` -> readImage,
    ``.pfm`` -> the data array from readPFM (the scale is dropped).

    Raises:
        Exception: for any other extension.
    """
    if file.endswith('.float3'):
        return readFloat(file)
    if file.endswith('.flo'):
        return readFlow(file)
    if file.endswith(('.ppm', '.pgm', '.png', '.jpg')):
        return readImage(file)
    if file.endswith('.pfm'):
        data, _scale = readPFM(file)
        return data
    raise Exception("don't know how to read %s" % file)
def write(file, data):
    """Save ``data`` to ``file`` with the writer that matches its extension.

    ``.float3`` -> writeFloat, ``.flo`` -> writeFlow,
    ``.ppm``/``.pgm``/``.png``/``.jpg`` -> writeImage, ``.pfm`` -> writePFM.
    Returns whatever the selected writer returns.

    Raises:
        Exception: for any other extension.
    """
    image_suffixes = ('.ppm', '.pgm', '.png', '.jpg')
    if file.endswith('.float3'):
        writer = writeFloat
    elif file.endswith('.flo'):
        writer = writeFlow
    elif file.endswith(image_suffixes):
        writer = writeImage
    elif file.endswith('.pfm'):
        writer = writePFM
    else:
        raise Exception("don't know how to write %s" % file)
    return writer(file, data)
def readPFM(file):
    """Read a PFM (Portable Float Map) image.

    Args:
        file: path to the ``.pfm`` file.

    Returns:
        ``(data, scale)``. ``data`` is a float32 array of shape
        ``(height, width, 3)`` for ``PF`` (colour) files or
        ``(height, width)`` for ``Pf`` (greyscale) files, flipped vertically
        after reading. ``scale`` is the absolute value of the header's scale
        field.

    Raises:
        Exception: if the magic line or the dimension line is malformed.
    """
    with open(file, 'rb') as fh:
        # Compare raw bytes so a non-PFM binary file yields the intended
        # error instead of a UnicodeDecodeError.
        header = fh.readline().rstrip()
        if header == b'PF':
            color = True
        elif header == b'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        # Tolerate any run of whitespace between and after the two sizes.
        dim_match = re.match(r'^(\d+)\s+(\d+)\s*$', fh.readline().decode('ascii'))
        if not dim_match:
            raise Exception('Malformed PFM header.')
        width, height = map(int, dim_match.groups())

        # The sign of the scale field selects the byte order of the payload:
        # negative means little-endian, otherwise big-endian.
        scale = float(fh.readline().decode('ascii').rstrip())
        if scale < 0:
            endian = '<'
            scale = -scale
        else:
            endian = '>'

        data = np.fromfile(fh, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    # PFM stores scanlines bottom-to-top; flip to top-to-bottom order.
    data = np.flipud(np.reshape(data, shape))
    return data, scale
def writePFM(file, image, scale=1):
    """Write a float32 image to ``file`` in PFM format.

    Args:
        file: destination path.
        image: float32 array shaped (H, W, 3) -> written as colour (``PF``),
            or (H, W) / (H, W, 1) -> written as greyscale (``Pf``).
        scale: magnitude stored in the header. Its sign is negated when the
            array's byte order is little-endian, as PFM requires.

    Raises:
        Exception: if ``image`` is not float32 or has an unsupported shape.
            Validation happens before the file is opened, so no empty file
            is left behind on error.
    """
    if image.dtype.name != 'float32':
        raise Exception('Image dtype must be float32.')

    if image.ndim == 3 and image.shape[2] == 3:
        color = True
    elif image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
        color = False
    else:
        raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')

    # PFM stores scanlines bottom-to-top.
    image = np.flipud(image)

    endian = image.dtype.byteorder
    if endian == '<' or (endian == '=' and sys.byteorder == 'little'):
        scale = -scale

    with open(file, 'wb') as fh:
        # Parenthesised so the chosen magic string is encoded in both cases;
        # without the parentheses a colour image would write a str to a
        # binary file and raise TypeError.
        fh.write(('PF\n' if color else 'Pf\n').encode())
        fh.write(b'%d %d\n' % (image.shape[1], image.shape[0]))
        fh.write(b'%f\n' % scale)
        image.tofile(fh)
def readFlow(name):
    """Read an optical-flow field.

    ``.pfm``/``.PFM`` paths are delegated to readPFM and only the first two
    channels are returned. Any other path is parsed as a ``.flo`` file: the
    4-byte magic ``PIEH``, little-endian int32 width and height, then
    ``width * height * 2`` little-endian float32 values.

    Returns:
        float32 array of shape (height, width, 2) for ``.flo`` files.

    Raises:
        Exception: if a ``.flo`` file does not start with ``PIEH``.
    """
    if name.endswith('.pfm') or name.endswith('.PFM'):
        return readPFM(name)[0][:, :, 0:2]

    with open(name, 'rb') as f:
        # Compare raw bytes: decoding arbitrary binary data could raise
        # UnicodeDecodeError instead of the intended message.
        if f.read(4) != b'PIEH':
            raise Exception('Flow file header does not contain PIEH')
        # Explicit little-endian dtypes keep the reader portable to
        # big-endian hosts; int() avoids int32 arithmetic on the sizes.
        width, height = (int(v) for v in np.fromfile(f, '<i4', 2))
        flow = np.fromfile(f, '<f4', width * height * 2)
    return flow.reshape((height, width, 2)).astype(np.float32)
def readImage(name):
    """Load an image file as an array.

    ``.pfm``/``.PFM`` paths go through readPFM. A 3-D result is truncated to
    its first three channels, and a 2-D result is returned as-is. Every other
    path is loaded with imageio's ``imread``.
    """
    if not name.endswith(('.pfm', '.PFM')):
        return imread(name)
    pixels = readPFM(name)[0]
    return pixels[:, :, 0:3] if pixels.ndim == 3 else pixels
def writeImage(name, data):
    """Save an image array to ``name``.

    ``.pfm``/``.PFM`` paths are written with writePFM using a scale of 1.
    Everything else is delegated to imageio's ``imwrite``. Returns the
    delegate's return value.
    """
    is_pfm = name.endswith(('.pfm', '.PFM'))
    if is_pfm:
        return writePFM(name, data, 1)
    return imwrite(name, data)
def writeFlow(name, flow):
    """Write an optical-flow field to ``name`` in ``.flo`` format.

    File layout: the magic ``PIEH``, then width and height as little-endian
    int32, then the field as little-endian float32 in row-major order.

    Args:
        name: destination path.
        flow: array-like of shape (height, width, 2); any numeric dtype is
            converted to float32.

    Raises:
        Exception: if ``flow`` is not shaped (height, width, 2). Such input
            would otherwise produce a file whose header does not describe
            its payload.
    """
    flow = np.asarray(flow)
    if flow.ndim != 3 or flow.shape[2] != 2:
        raise Exception('Flow must have H x W x 2 dimensions, got %s' % (flow.shape,))
    with open(name, 'wb') as f:
        f.write(b'PIEH')
        # Explicit little-endian dtypes keep the output portable.
        np.array([flow.shape[1], flow.shape[0]], dtype='<i4').tofile(f)
        flow.astype('<f4').tofile(f)
def readFloat(name):
    """Read a ``.float3`` array file.

    Layout: a ``float`` line, a line with the number of dimensions ``dim``,
    ``dim`` lines with one size each, then raw float32 values.

    Sizes are listed fastest-varying first. For file sizes ``[d0, d1]`` the
    result has shape ``(d1, d0)``. For ``[d0, d1, d2]`` the payload is laid
    out as ``(d2, d1, d0)`` and is returned as ``(d1, d0, d2)``.

    Raises:
        Exception: if the first line is not ``float``.
    """
    with open(name, 'rb') as f:
        # Compare raw bytes so binary junk yields the intended error rather
        # than a UnicodeDecodeError.
        if f.readline() != b'float\n':
            raise Exception('float file %s did not contain <float> keyword' % name)
        dim = int(f.readline())
        dims = [int(f.readline()) for _ in range(dim)]
        count = 1
        for d in dims:
            count *= d
        data = np.fromfile(f, np.float32, count).reshape(list(reversed(dims)))
    if dim > 2:
        # (d2, d1, d0) -> (d0, d1, d2) -> (d1, d0, d2)
        data = np.transpose(data, (2, 1, 0))
        data = np.transpose(data, (1, 0, 2))
    return data
def writeFloat(name, data):
    """Write a 1-, 2- or 3-D array to ``name`` in ``.float3`` format.

    The header is a ``float`` line, the number of dimensions, then the sizes
    fastest-varying first. For 2-D/3-D data that is ``shape[1]``, then
    ``shape[0]``, then any remaining axes. Values follow as float32. 1-D and
    2-D data are written in row-major order. 3-D data is written transposed
    to ``(shape[2], shape[0], shape[1])``.

    Raises:
        Exception: if ``data`` has fewer than 1 or more than 3 dimensions.
            This is checked before the file is opened.
    """
    dim = len(data.shape)
    if dim < 1 or dim > 3:
        raise Exception('bad float file dimension: %d' % dim)

    if dim == 1:
        sizes = [data.shape[0]]
    else:
        sizes = [data.shape[1], data.shape[0]] + list(data.shape[2:])
    header = ('float\n%d\n' % dim) + ''.join('%d\n' % s for s in sizes)

    data = data.astype(np.float32)
    with open(name, 'wb') as f:
        f.write(header.encode('ascii'))
        if dim <= 2:
            # 1-D and 2-D payloads are written as-is. Previously 1-D data
            # fell into the 3-D transpose branch and crashed.
            data.tofile(f)
        else:
            np.transpose(data, (2, 0, 1)).tofile(f)
def check_dim_and_resize(tensor_list):
    """Make all tensors share the trailing (spatial) size of the first one.

    Sizes are compared on ``shape[2:]``. If they already agree, or the list
    is empty, the input list object itself is returned. Otherwise a message
    is printed and a new list is returned, with every tensor bilinearly
    interpolated to the first tensor's ``shape[2:]``.

    NOTE(review): ``mode='bilinear'`` presumably expects 4-D (N, C, H, W)
    tensors; confirm with callers.
    """
    spatial_shapes = [tensor.shape[2:] for tensor in tensor_list]
    if len(set(spatial_shapes)) <= 1:
        return tensor_list
    target = spatial_shapes[0]
    print(f'Inconsistent size of input video frames. All frames will be resized to {target}')
    return [
        torch.nn.functional.interpolate(tensor, size=tuple(target), mode='bilinear')
        for tensor in tensor_list
    ]
# Command-line interface: only the dataset root is configurable.
parser = argparse.ArgumentParser(prog='AMT', description='Flow generation')
parser.add_argument('-r', '--root', default='../data/vimeo_triplet')
args = parser.parse_args()

# Frames are read from <root>/sequences; flows are written to <root>/flow.
vimeo90k_dir = args.root
vimeo90k_sequences_dir, vimeo90k_flow_dir = (
    osp.join(vimeo90k_dir, subdir) for subdir in ('sequences', 'flow')
)
def pred_flow(img1, img2):
    """Estimate optical flow between two images with LiteFlowNet's ``estimate``.

    Each input is converted to a float tensor with its last axis moved to the
    front and scaled by 1/255. The result of ``estimate`` has its first axis
    moved to the end and is returned as a NumPy array.

    NOTE(review): assumes img1/img2 are HxWxC arrays in the 0-255 range
    (``permute(2, 0, 1)`` requires exactly 3 dims). The returned array is
    presumably (H, W, 2) if ``estimate`` yields (2, H, W); confirm.
    """
    def _to_chw_unit(image):
        return torch.from_numpy(image).float().permute(2, 0, 1) / 255.0

    flow_chw = estimate(_to_chw_unit(img1), _to_chw_unit(img2))
    return flow_chw.permute(1, 2, 0).cpu().numpy()
# Driver: for every <root>/sequences/<seq>/<id>/ folder, read im1/im2/im3,
# estimate flow from im2 to im1 and from im2 to im3, and write them as
# flow_t0.flo / flow_t1.flo under <root>/flow/<seq>/<id>/.
# Output folders are created on the fly in a single pass over the dataset.
print('Built Flow Path')
os.makedirs(vimeo90k_flow_dir, exist_ok=True)

for sequences_path in sorted(os.listdir(vimeo90k_sequences_dir)):
    vimeo90k_sequences_path_dir = osp.join(vimeo90k_sequences_dir, sequences_path)
    # Skip stray files (README, .DS_Store, ...); os.listdir would fail on them.
    if not osp.isdir(vimeo90k_sequences_path_dir):
        continue
    vimeo90k_flow_path_dir = osp.join(vimeo90k_flow_dir, sequences_path)
    os.makedirs(vimeo90k_flow_path_dir, exist_ok=True)

    for sequences_id in sorted(os.listdir(vimeo90k_sequences_path_dir)):
        vimeo90k_sequences_id_dir = osp.join(vimeo90k_sequences_path_dir, sequences_id)
        if not osp.isdir(vimeo90k_sequences_id_dir):
            continue
        vimeo90k_flow_id_dir = osp.join(vimeo90k_flow_path_dir, sequences_id)
        # exist_ok makes re-runs safe without a separate exists() check.
        os.makedirs(vimeo90k_flow_id_dir, exist_ok=True)

        img0 = read(osp.join(vimeo90k_sequences_id_dir, 'im1.png'))
        imgt = read(osp.join(vimeo90k_sequences_id_dir, 'im2.png'))
        img1 = read(osp.join(vimeo90k_sequences_id_dir, 'im3.png'))

        # Both flows pass the middle frame (im2) as the first argument.
        flow_t0 = pred_flow(imgt, img0)
        flow_t1 = pred_flow(imgt, img1)

        write(osp.join(vimeo90k_flow_id_dir, 'flow_t0.flo'), flow_t0)
        write(osp.join(vimeo90k_flow_id_dir, 'flow_t1.flo'), flow_t1)

    print('Written Sequences {}'.format(sequences_path))