# SuyeonJ's picture
# Upload folder using huggingface_hub
# 8d015d4 verified
# raw
# history blame
# 8.26 kB
import os
import sys
import torch
import argparse
import numpy as np
import os.path as osp
import re
from imageio import imread, imwrite
import torch.nn.functional as F
sys.path.append('.')
from utils.flow_generation.liteflownet.run import estimate
def read(file):
    """Load ``file`` with the reader that matches its filename extension.

    ``.float3`` is handled by ``readFloat``, ``.flo`` by ``readFlow``, the
    image extensions ``.ppm``/``.pgm``/``.png``/``.jpg`` by ``readImage`` and
    ``.pfm`` by ``readPFM`` (only the data array is returned; the scale is
    discarded). Matching is case-sensitive.

    Raises:
        Exception: if the extension is not one of the above.
    """
    if file.endswith('.float3'):
        return readFloat(file)
    if file.endswith('.flo'):
        return readFlow(file)
    # All four raster formats share the same reader.
    if file.endswith(('.ppm', '.pgm', '.png', '.jpg')):
        return readImage(file)
    if file.endswith('.pfm'):
        return readPFM(file)[0]
    raise Exception('don\'t know how to read %s' % file)
def write(file, data):
    """Store ``data`` in ``file`` with the writer matching its extension.

    ``.float3`` is handled by ``writeFloat``, ``.flo`` by ``writeFlow``, the
    image extensions ``.ppm``/``.pgm``/``.png``/``.jpg`` by ``writeImage`` and
    ``.pfm`` by ``writePFM``. Matching is case-sensitive. The selected
    writer's return value is passed through.

    Raises:
        Exception: if the extension is not one of the above.
    """
    if file.endswith('.float3'):
        return writeFloat(file, data)
    if file.endswith('.flo'):
        return writeFlow(file, data)
    # All four raster formats share the same writer.
    if file.endswith(('.ppm', '.pgm', '.png', '.jpg')):
        return writeImage(file, data)
    if file.endswith('.pfm'):
        return writePFM(file, data)
    raise Exception('don\'t know how to write %s' % file)
def readPFM(file):
    """Read a PFM (portable float map) file.

    Args:
        file: Path of the file to read.

    Returns:
        A ``(data, scale)`` tuple. ``data`` is a float32 array shaped
        ``(height, width, 3)`` for a ``PF`` header and ``(height, width)``
        for a ``Pf`` header, flipped vertically relative to the row order
        stored in the file. ``scale`` is the absolute value of the header's
        scale field.

    Raises:
        Exception: if the magic line is neither ``PF`` nor ``Pf``, or the
            dimension line does not match ``<width> <height>``.
    """
    # The context manager guarantees the handle is released even when a
    # header check fails.
    with open(file, 'rb') as stream:
        # Compare raw bytes so that a non-ASCII (binary) first line is
        # reported as "Not a PFM file." rather than as a decode error.
        header = stream.readline().rstrip()
        if header == b'PF':
            color = True
        elif header == b'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        dim_match = re.match(r'^(\d+)\s(\d+)\s$', stream.readline().decode("ascii"))
        if not dim_match:
            raise Exception('Malformed PFM header.')
        width, height = map(int, dim_match.groups())

        # The sign of the scale field selects the byte order of the payload:
        # negative means little-endian, otherwise big-endian.
        scale = float(stream.readline().decode("ascii").rstrip())
        if scale < 0:
            endian = '<'
            scale = -scale
        else:
            endian = '>'

        data = np.fromfile(stream, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    # Rows are stored in the opposite vertical order; flip them back.
    data = np.flipud(data)
    return data, scale
def writePFM(file, image, scale=1):
    """Write a float32 array to ``file`` in PFM (portable float map) format.

    Args:
        file: Destination path.
        image: float32 array of shape ``(H, W, 3)`` (written with a ``PF``
            header) or ``(H, W)`` / ``(H, W, 1)`` (written with ``Pf``).
        scale: Value for the header's scale field. It is negated when the
            array data is little-endian.

    Raises:
        Exception: if ``image`` is not float32 or has an unsupported shape.
    """
    # Validate before opening so a bad input cannot truncate an existing file.
    if image.dtype.name != 'float32':
        raise Exception('Image dtype must be float32.')

    # Rows are stored in the opposite vertical order.
    image = np.flipud(image)

    if len(image.shape) == 3 and image.shape[2] == 3:
        color = True
    elif len(image.shape) == 2 or (len(image.shape) == 3 and image.shape[2] == 1):
        color = False
    else:
        raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')

    # A negative scale marks little-endian payload data.
    endian = image.dtype.byteorder
    if endian == '<' or (endian == '=' and sys.byteorder == 'little'):
        scale = -scale

    with open(file, 'wb') as stream:
        # Both magic strings must be bytes; previously only the 'Pf' branch
        # was encoded, so colour images failed with a TypeError.
        stream.write(b'PF\n' if color else b'Pf\n')
        stream.write(b'%d %d\n' % (image.shape[1], image.shape[0]))
        stream.write(b'%f\n' % scale)
        image.tofile(stream)
def readFlow(name):
    """Read an optical-flow field from a ``.flo`` or ``.pfm`` file.

    Args:
        name: Path of the flow file. Paths ending in ``.pfm``/``.PFM`` are
            read with ``readPFM`` and the first two channels are returned.

    Returns:
        For ``.flo`` input, a float32 array of shape ``(height, width, 2)``.

    Raises:
        Exception: if the file does not start with the ``PIEH`` magic or the
            width/height header is incomplete.
    """
    if name.endswith(('.pfm', '.PFM')):
        return readPFM(name)[0][:, :, 0:2]

    with open(name, 'rb') as f:
        # Compare raw bytes: decoding an arbitrary binary header could raise
        # UnicodeDecodeError instead of the intended error.
        if f.read(4) != b'PIEH':
            raise Exception('Flow file header does not contain PIEH')

        dims = np.fromfile(f, np.int32, 2)
        if dims.size != 2:
            raise Exception('Flow file header is truncated')
        # Plain Python ints keep the element count out of int32 arithmetic.
        width, height = int(dims[0]), int(dims[1])

        flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2))

    return flow.astype(np.float32)
def readImage(name):
    """Read an image file and return its pixel data.

    Paths ending in ``.pfm``/``.PFM`` are read with ``readPFM``; for 3-D
    results only the first three channels are kept. Every other path is
    handed to ``imread``.
    """
    if not (name.endswith('.pfm') or name.endswith('.PFM')):
        return imread(name)

    pixels = readPFM(name)[0]
    if len(pixels.shape) != 3:
        return pixels
    return pixels[:, :, 0:3]
def writeImage(name, data):
    """Write image ``data`` to ``name``.

    Paths ending in ``.pfm``/``.PFM`` are written with ``writePFM`` using a
    scale of 1; every other path is handed to ``imwrite``. The chosen
    writer's return value is passed through.
    """
    is_pfm = name.endswith('.pfm') or name.endswith('.PFM')
    if not is_pfm:
        return imwrite(name, data)
    return writePFM(name, data, 1)
def writeFlow(name, flow):
    """Write a flow field to ``name`` in ``.flo`` format.

    The file consists of the 4-byte magic ``PIEH``, then ``flow.shape[1]``
    and ``flow.shape[0]`` as int32, then the array contents as float32.

    Args:
        name: Destination path.
        flow: Array with at least two dimensions; its first two axes are
            written to the header as height and width.
    """
    # Prepare everything before opening so a bad ``flow`` cannot leave a
    # truncated file behind.
    dims = np.array([flow.shape[1], flow.shape[0]], dtype=np.int32)
    payload = flow.astype(np.float32)

    # The context manager guarantees the data is flushed and the handle closed.
    with open(name, 'wb') as f:
        f.write(b'PIEH')
        dims.tofile(f)
        payload.tofile(f)
def readFloat(name):
    """Read an array from a ``.float3``-style file.

    The file starts with a ``float`` line, a line holding the number of
    dimensions, and one line per dimension size, followed by raw float32
    values.

    Args:
        name: Path of the file to read.

    Returns:
        A float32 array. The sizes listed in the header are reversed to form
        the array shape; arrays with more than two dimensions are additionally
        transposed so the leading stored axis becomes the last axis.

    Raises:
        Exception: if the first line is not ``float``.
    """
    with open(name, 'rb') as f:
        # Compare raw bytes so a binary file triggers the intended error
        # instead of a UnicodeDecodeError.
        if f.readline() != b'float\n':
            raise Exception('float file %s did not contain <float> keyword' % name)

        dim = int(f.readline())
        dims = [int(f.readline()) for _ in range(dim)]

        count = 1
        for size in dims:
            count *= size

        data = np.fromfile(f, np.float32, count).reshape(dims[::-1])

    if dim > 2:
        # Equivalent to transposing by (2, 1, 0) and then by (1, 0, 2).
        data = np.transpose(data, (1, 2, 0))

    return data
def writeFloat(name, data):
    """Write a 1-D, 2-D or 3-D array to a ``.float3``-style file.

    The header is a ``float`` line, the number of dimensions, then one size
    per line: ``shape[0]`` for 1-D input, otherwise ``shape[1]``,
    ``shape[0]`` and (for 3-D) ``shape[2]``. The values follow as float32;
    3-D arrays are stored with their last axis moved to the front.

    Args:
        name: Destination path.
        data: Array with one to three dimensions.

    Raises:
        Exception: if ``data`` has zero or more than three dimensions.
    """
    # Validate before opening so a bad input cannot truncate an existing file.
    dim = len(data.shape)
    if dim < 1 or dim > 3:
        raise Exception('bad float file dimension: %d' % dim)

    if dim == 1:
        sizes = [data.shape[0]]
    else:
        sizes = [data.shape[1], data.shape[0]] + list(data.shape[2:])

    data = data.astype(np.float32)
    # Only 3-D data needs reordering; 1-D input previously crashed here
    # because it was also sent through the 3-axis transpose.
    if dim == 3:
        data = np.transpose(data, (2, 0, 1))

    with open(name, 'wb') as f:
        f.write(('float\n').encode('ascii'))
        f.write(('%d\n' % dim).encode('ascii'))
        for size in sizes:
            f.write(('%d\n' % size).encode('ascii'))
        data.tofile(f)
def check_dim_and_resize(tensor_list):
    """Make every tensor in ``tensor_list`` share the same trailing size.

    The dimensions from index 2 onwards of each tensor are compared. If they
    all agree (or the list is empty) the input list is returned untouched.
    Otherwise a notice is printed and a new list is returned in which every
    tensor has been bilinearly interpolated to the trailing size of the
    first tensor.
    """
    trailing_sizes = [tensor.shape[2:] for tensor in tensor_list]
    if len(set(trailing_sizes)) <= 1:
        return tensor_list

    desired_shape = trailing_sizes[0]
    print(f'Inconsistent size of input video frames. All frames will be resized to {desired_shape}')
    return [
        torch.nn.functional.interpolate(tensor, size=tuple(desired_shape), mode='bilinear')
        for tensor in tensor_list
    ]
# Command-line interface: a single option selecting the dataset root.
parser = argparse.ArgumentParser(prog='AMT', description='Flow generation')
parser.add_argument('-r', '--root', default='../data/vimeo_triplet')
args = parser.parse_args()

# Frames are read from <root>/sequences; flow files are written to <root>/flow.
vimeo90k_dir = args.root
vimeo90k_sequences_dir, vimeo90k_flow_dir = (
    osp.join(vimeo90k_dir, subdir) for subdir in ('sequences', 'flow')
)
def pred_flow(img1, img2):
    """Estimate the flow between two images with ``estimate``.

    Each input is converted to a float tensor, its last axis is moved to the
    front and the values are divided by 255 before being passed to
    ``estimate``. The result is moved back to last-axis layout and returned
    as a NumPy array.

    NOTE(review): presumably the inputs are H x W x C images with values in
    [0, 255] and the result is H x W x 2 - confirm against ``estimate``.
    """
    prepared = []
    for frame in (img1, img2):
        channels_first = torch.from_numpy(frame).float().permute(2, 0, 1)
        prepared.append(channels_first / 255.0)

    predicted = estimate(prepared[0], prepared[1])
    return predicted.permute(1, 2, 0).cpu().numpy()
def _ensure_dir(path, make=os.mkdir):
    """Create ``path`` with ``make`` unless something already exists there."""
    if not osp.exists(path):
        make(path)


print('Built Flow Path')
_ensure_dir(vimeo90k_flow_dir, os.makedirs)

# Pass 1: mirror the two-level directory layout of the sequences tree under
# the flow directory.
for group_name in sorted(os.listdir(vimeo90k_sequences_dir)):
    group_src_dir = osp.join(vimeo90k_sequences_dir, group_name)
    group_flow_dir = osp.join(vimeo90k_flow_dir, group_name)
    _ensure_dir(group_flow_dir)

    for clip_name in sorted(os.listdir(group_src_dir)):
        _ensure_dir(osp.join(group_flow_dir, clip_name))

# Pass 2: for every triplet, estimate the flow from the middle frame (im2) to
# each outer frame (im1, im3) and store both results as .flo files.
for sequences_path in sorted(os.listdir(vimeo90k_sequences_dir)):
    group_src_dir = os.path.join(vimeo90k_sequences_dir, sequences_path)
    group_flow_dir = os.path.join(vimeo90k_flow_dir, sequences_path)

    for clip_name in sorted(os.listdir(group_src_dir)):
        clip_src_dir = os.path.join(group_src_dir, clip_name)
        clip_flow_dir = os.path.join(group_flow_dir, clip_name)

        img0 = read(clip_src_dir + '/im1.png')
        imgt = read(clip_src_dir + '/im2.png')
        img1 = read(clip_src_dir + '/im3.png')

        flow_t0 = pred_flow(imgt, img0)
        flow_t1 = pred_flow(imgt, img1)

        write(clip_flow_dir + '/flow_t0.flo', flow_t0)
        write(clip_flow_dir + '/flow_t1.flo', flow_t1)

    print('Written Sequences {}'.format(sequences_path))