# VfiTest / prepare_extra_training_dataset.py (uploaded by SuyeonJ via huggingface_hub, commit 8d015d4, 4.81 kB)
# [START COMMAND]
# python3 -m prepare_extra_training_dataset \
# --video_path ../datasets/video1.mov \
# --frame_save_root ../datasets/frames1 \
# --tri_save_root ../datasets/frames1_triplet/sequences/0001 \
# --dup_info ../datasets/dup_video1.txt \
# --flow_save_root ../datasets/frames1_unimatch_flow/sequences \
# --total_gpus 1 --start_cuda_index 1 --cuda_index 1 --first_scaling 2
# python3 -m prepare_extra_training_dataset \
# --video_path ../datasets/video2.mov \
# --frame_save_root ../datasets/frames2 \
# --tri_save_root ../datasets/frames2_triplet/sequences/0001 \
# --dup_info ../datasets/dup_video2.txt \
# --flow_save_root ../datasets/frames2_unimatch_flow/sequences \
# --total_gpus 1 --start_cuda_index 1 --cuda_index 1 --first_scaling 2
# python3 -m prepare_extra_training_dataset \
# --video_path ../datasets/video3.mov \
# --frame_save_root ../datasets/frames3 \
# --tri_save_root ../datasets/frames3_triplet/sequences/0001 \
# --dup_info ../datasets/dup_video3.txt \
# --flow_save_root ../datasets/frames3_unimatch_flow/sequences \
# --total_gpus 1 --start_cuda_index 1 --cuda_index 1 --first_scaling 2
# python3 -m prepare_extra_training_dataset \
# --video_path ../datasets/video4.mov \
# --frame_save_root ../datasets/frames4 \
# --tri_save_root ../datasets/frames4_triplet/sequences/0001 \
# --dup_info ../datasets/dup_video4.txt \
# --flow_save_root ../datasets/frames4_unimatch_flow/sequences \
# --total_gpus 1 --start_cuda_index 1 --cuda_index 1 --first_scaling 2
import argparse
import glob
import os
import shutil
import subprocess
import sys

import cv2
from tqdm import tqdm
def _extract_frames(video_path, frame_save_root, dup_frames):
    """Decode ``video_path`` and save every non-duplicated frame as a PNG.

    Args:
        video_path: path of the input video.
        frame_save_root: directory that receives ``0000.png``, ``0001.png``, ...
        dup_frames: set of 0-based decode indices to skip.

    Returns:
        The written PNG paths in temporal order. They are returned explicitly
        rather than re-globbed from disk, so stale frames from an earlier run
        are never picked up. This also keeps the order correct past 9999
        frames, where zero-padded names stop sorting lexicographically.

    Raises:
        OSError: if OpenCV cannot open the video or fails to write a frame.
    """
    os.makedirs(frame_save_root, exist_ok=True)
    video = cv2.VideoCapture(video_path)
    if not video.isOpened():
        raise OSError(f'cannot open video: {video_path}')
    saved = []
    try:
        # CAP_PROP_FRAME_COUNT may be inaccurate for some containers/backends.
        # Stop at the first failed read instead of handing None to imwrite.
        num_frame = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        for index in tqdm(range(num_frame)):
            ok, frame = video.read()
            if not ok:
                break
            if index in dup_frames:
                continue
            newfile = os.path.join(frame_save_root, str(len(saved)).zfill(4) + '.png')
            if not cv2.imwrite(newfile, frame):
                raise OSError(f'failed to write frame: {newfile}')
            saved.append(newfile)
    finally:
        video.release()
    return saved


def _write_triplets(frame_files, tri_save_root):
    """Save every window of three consecutive frames as a training triplet.

    Triplet k (k = 1 .. len(frame_files) - 2) is stored as
    ``tri_save_root/<k:04d>/im1.png, im2.png, im3.png`` (frames k-1, k, k+1).
    Each entry ``<basename(tri_save_root)>/<k:04d>`` is written to
    ``tri_save_root/../../tri_trainlist.txt``, which is truncated first.
    With fewer than three frames nothing is written and an existing list
    file is left untouched.
    """
    if len(frame_files) < 3:
        return
    # tri_save_root must exist before the '..'-relative list path can be opened.
    os.makedirs(tri_save_root, exist_ok=True)
    list_path = os.path.join(tri_save_root, '..', '..', 'tri_trainlist.txt')
    seq_name = os.path.basename(os.path.normpath(tri_save_root))
    with open(list_path, 'w') as list_file:
        for k in tqdm(range(1, len(frame_files) - 1)):
            name = str(k).zfill(4)
            folder = os.path.join(tri_save_root, name)
            os.makedirs(folder, exist_ok=True)
            # Frames are already PNGs; copying avoids a decode/re-encode per frame.
            for slot, src in enumerate(frame_files[k - 1:k + 2], start=1):
                shutil.copyfile(src, os.path.join(folder, f'im{slot}.png'))
            # List entries always use '/' regardless of the host OS separator.
            list_file.write(f'{seq_name}/{name}\n')


def _run_unimatch(args):
    """Run ``python -m unimatch_inference`` inside ``args.unimatch_dir``.

    All paths are made absolute first, because they are given relative to
    the current directory and the subprocess runs in a different one.
    Raises ``subprocess.CalledProcessError`` if the inference exits non-zero.
    """
    cmd = [
        sys.executable, '-m', 'unimatch_inference',
        '--total_gpus', str(args.total_gpus),
        '--start_cuda_index', str(args.start_cuda_index),
        '--cuda_index', str(args.cuda_index),
        '--root', os.path.abspath(os.path.join(args.tri_save_root, '..')),
        '--save_root', os.path.abspath(args.flow_save_root),
        '--first_scaling', str(args.first_scaling),
    ]
    # cwd= replaces os.chdir, so this process's working directory never changes.
    subprocess.run(cmd, cwd=os.path.abspath(args.unimatch_dir), check=True)


def main():
    """Build an extra VFI training set from one video.

    Steps:
      1. Save every frame not listed in ``--dup_info`` to ``--frame_save_root``.
         ``--dup_info`` holds one 0-based frame index per line; blank lines
         are ignored.
      2. Save sliding triplets of consecutive saved frames under
         ``--tri_save_root`` and write ``tri_trainlist.txt`` two levels above it.
      3. Compute UniMatch optical flows for the triplets into ``--flow_save_root``.
    """
    parser = argparse.ArgumentParser(description="Preparing extra training dataset for UPR-Net-back inference.")
    parser.add_argument('--video_path', type=str, default='../datasets/video4.mov', help="video file path")
    parser.add_argument('--frame_save_root', type=str, default='../datasets/frames4', help="root to save frames")
    parser.add_argument('--tri_save_root', type=str, default='../datasets/frames4_triplet/sequences', help="root to save triplets")
    parser.add_argument('--dup_info', type=str, default='../datasets/dup_video4.txt', help="duplicated frames information")
    parser.add_argument('--flow_save_root', type=str, default='../datasets/frames4_unimatch_flow/sequences', help="root to save UniMatch optical flows")
    # UniMatch parameters
    parser.add_argument('--total_gpus', type=int, default=2, help="number of CUDA GPUs to use")
    parser.add_argument("--start_cuda_index", type=int, default=0, help="starting CUDA GPU index")
    parser.add_argument("--cuda_index", type=int, default=0, help="CUDA GPU index")
    parser.add_argument("--first_scaling", type=int, default=1, help="downsizing ratio before computing flow")
    parser.add_argument("--unimatch_dir", type=str, default='../unimatch_inference',
                        help="directory containing the unimatch_inference module (relative to the current directory)")
    args = parser.parse_args()

    with open(args.dup_info, 'r') as f:
        # A set gives O(1) membership tests in the per-frame loop.
        dup_frames = {int(line.strip()) for line in f if line.strip()}

    print('SAVE Frames')
    frame_files = _extract_frames(args.video_path, args.frame_save_root, dup_frames)

    print('SAVE Triplets')
    _write_triplets(frame_files, args.tri_save_root)

    _run_unimatch(args)
if __name__ == "__main__":
    # Run the pipeline only when executed directly (e.g. `python3 -m prepare_extra_training_dataset`),
    # never on import.
    main()