# samtrack/aot/dataloaders/eval_datasets.py
# Hugging Face Hub page header (kept for provenance): aikenml's picture,
# "Upload folder using huggingface_hub", commit c985ba4
from __future__ import division
import os
import shutil
import json
import cv2
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
from utils.image import _palette
class VOSTest(Dataset):
    """Frame-by-frame evaluation dataset over a single video sequence.

    Indexing returns a dict with ``'current_img'``, a ``'meta'`` dict and,
    for frames whose label file name is listed in ``labels``, a
    ``'current_label'`` array.

    Args:
        image_root: Directory holding one sub-folder of frames per sequence.
        label_root: Directory holding one sub-folder of label PNGs per
            sequence.
        seq_name: Name of the sequence sub-folder.
        images: Ordered frame file names of the sequence.
        labels: Label file names that are available for the sequence.
        rgb: If True, the channel order returned by OpenCV is reversed.
        transform: Optional callable applied to each sample dict.
        single_obj: If True, every non-zero label value is merged into 1.
        resolution: If given, the ``height`` reported in ``meta`` is set to
            this value and ``width`` is scaled by the same ratio (rounded
            up). The image array itself is not resized here.

    Attributes:
        obj_nums: Per frame, the number of non-background object ids known
            before that frame; the first entry is the count after the first
            frame has been read.
        obj_indices: Per frame, the object ids (starting with 0) seen up to
            and including that frame, in order of first appearance.
    """

    def __init__(self,
                 image_root,
                 label_root,
                 seq_name,
                 images,
                 labels,
                 rgb=True,
                 transform=None,
                 single_obj=False,
                 resolution=None):
        self.image_root = image_root
        self.label_root = label_root
        self.seq_name = seq_name
        self.images = images
        self.labels = labels
        self.obj_num = 1
        self.num_frame = len(self.images)
        self.transform = transform
        self.rgb = rgb
        self.single_obj = single_obj
        self.resolution = resolution

        self.obj_nums = []
        self.obj_indices = []

        curr_objs = [0]
        for img_name in self.images:
            # Count is taken *before* this frame's label is inspected.
            self.obj_nums.append(len(curr_objs) - 1)
            current_label_name = self._label_name(img_name)
            if current_label_name in self.labels:
                current_label = self.read_label(current_label_name)
                for obj_idx in np.unique(current_label):
                    if obj_idx not in curr_objs:
                        curr_objs.append(obj_idx)
            self.obj_indices.append(curr_objs.copy())

        # The first frame reports the objects found in its own label. This
        # equals the former ``obj_nums[1]`` for sequences with two or more
        # frames, and also works for one-frame and empty sequences.
        if self.obj_nums:
            self.obj_nums[0] = len(self.obj_indices[0]) - 1

    @staticmethod
    def _label_name(img_name):
        """Return the label file name (``.png``) matching an image name."""
        # splitext only strips the final extension, so dots inside the
        # frame name are preserved.
        return os.path.splitext(img_name)[0] + '.png'

    def __len__(self):
        return len(self.images)

    def read_image(self, idx):
        """Load frame ``idx`` as a float32 array.

        Raises:
            OSError: If OpenCV could not read the image file.
        """
        img_name = self.images[idx]
        img_path = os.path.join(self.image_root, self.seq_name, img_name)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError('Failed to read image: {}'.format(img_path))
        img = np.array(img, dtype=np.float32)
        if self.rgb:
            img = img[:, :, [2, 1, 0]]
        return img

    def read_label(self, label_name, squeeze_idx=None):
        """Load a label image as a uint8 array.

        Args:
            label_name: File name of the label inside the sequence folder.
            squeeze_idx: Optional list of object ids. When given (and
                ``single_obj`` is False), each id is replaced by its
                position in this list; ids not listed, and id 0, map to 0.

        Returns:
            The label array; binarised when ``single_obj`` is True.
        """
        label_path = os.path.join(self.label_root, self.seq_name, label_name)
        with Image.open(label_path) as label_file:
            label = np.array(label_file, dtype=np.uint8)
        if self.single_obj:
            label = (label > 0).astype(np.uint8)
        elif squeeze_idx is not None:
            squeezed_label = label * 0
            for idx, obj_id in enumerate(squeeze_idx):
                if obj_id == 0:
                    continue
                mask = label == obj_id
                squeezed_label += (mask * idx).astype(np.uint8)
            label = squeezed_label
        return label

    def __getitem__(self, idx):
        img_name = self.images[idx]
        current_img = self.read_image(idx)
        height, width = current_img.shape[:2]
        if self.resolution is not None:
            width = int(
                np.ceil(float(width) * self.resolution / float(height)))
            height = int(self.resolution)

        current_label_name = self._label_name(img_name)
        obj_num = self.obj_nums[idx]
        obj_idx = self.obj_indices[idx]

        sample = {'current_img': current_img}
        if current_label_name in self.labels:
            sample['current_label'] = self.read_label(current_label_name,
                                                      obj_idx)

        sample['meta'] = {
            'seq_name': self.seq_name,
            'frame_num': self.num_frame,
            'obj_num': obj_num,
            'current_name': img_name,
            'height': height,
            'width': width,
            'flip': False,
            'obj_idx': obj_idx
        }

        if self.transform is not None:
            sample = self.transform(sample)
        return sample
class YOUTUBEVOS_Test(object):
    """Collection of YouTube-VOS evaluation sequences.

    Indexing returns a ``VOSTest`` dataset for one sequence listed in
    ``<root>/<year>/<split>/meta.json``.

    Args:
        root: Dataset root directory.
        year: Dataset year; used as a sub-folder name.
        split: Split name; ``'val'`` is mapped to the ``'valid'`` folder.
        transform: Optional callable forwarded to each ``VOSTest``.
        rgb: Forwarded to each ``VOSTest``.
        result_root: Directory that receives a copy of each sequence's
            first label. When None, nothing is copied.

    Raises:
        FileNotFoundError: If ``meta.json`` does not exist.
    """

    def __init__(self,
                 root='./datasets/YTB',
                 year=2018,
                 split='val',
                 transform=None,
                 rgb=True,
                 result_root=None):
        if split == 'val':
            split = 'valid'
        root = os.path.join(root, str(year), split)
        self.db_root_dir = root
        self.result_root = result_root
        self.rgb = rgb
        self.transform = transform
        self.seq_list_file = os.path.join(self.db_root_dir, 'meta.json')
        if not self._check_preprocess():
            # Fail here with a clear message rather than later with an
            # AttributeError on the missing ``ann_f``.
            raise FileNotFoundError(
                'Sequence list file not found: {}'.format(
                    self.seq_list_file))
        self.seqs = list(self.ann_f.keys())
        self.image_root = os.path.join(root, 'JPEGImages')
        self.label_root = os.path.join(root, 'Annotations')

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        seq_name = self.seqs[idx]
        data = self.ann_f[seq_name]['objects']

        images = []
        labels = []
        for obj_info in data.values():
            frames = obj_info["frames"]
            images.extend(frame + '.jpg' for frame in frames)
            # Only the first frame listed for each object is used as a label.
            labels.append(frames[0] + '.png')
        images = np.sort(np.unique(images))
        labels = np.sort(np.unique(labels))

        self._copy_first_label(seq_name, labels)

        seq_dataset = VOSTest(self.image_root,
                              self.label_root,
                              seq_name,
                              images,
                              labels,
                              transform=self.transform,
                              rgb=self.rgb)
        return seq_dataset

    def _copy_first_label(self, seq_name, labels):
        """Best-effort copy of the first label into the result folder."""
        if self.result_root is None:
            return
        try:
            seq_result_dir = os.path.join(self.result_root, seq_name)
            target = os.path.join(seq_result_dir, labels[0])
            if not os.path.isfile(target):
                # exist_ok avoids a failure when another worker creates the
                # folder between the check and the call.
                os.makedirs(seq_result_dir, exist_ok=True)
                shutil.copy(
                    os.path.join(self.label_root, seq_name, labels[0]),
                    target)
        except Exception as inst:
            # Deliberately broad: preparing the result folder must never
            # stop evaluation of the sequence.
            print(inst)
            print('Failed to create a result folder for sequence {}.'.format(
                seq_name))

    def _check_preprocess(self):
        """Load ``meta.json`` into ``self.ann_f``.

        Returns:
            True when the file was loaded; False (after printing the path)
            when it does not exist.
        """
        _seq_list_file = self.seq_list_file
        if not os.path.isfile(_seq_list_file):
            print(_seq_list_file)
            return False
        with open(_seq_list_file, 'r') as meta_file:
            self.ann_f = json.load(meta_file)['videos']
        return True
class YOUTUBEVOS_DenseTest(object):
    """Collection of YouTube-VOS sequences evaluated on all frames.

    Frames are read from ``<root>/<year>/<split>_all_frames/JPEGImages``;
    the sequence list and labels come from the sparse
    ``<root>/<year>/<split>`` folder. Indexing returns a ``VOSTest`` dataset
    whose frames span from the first to the last sparsely annotated frame,
    with the sparse frame names attached as ``images_sparse``.

    Args:
        root: Dataset root directory.
        year: Dataset year; used as a sub-folder name.
        split: Split name; ``'val'`` is mapped to the ``'valid'`` folder.
        transform: Optional callable forwarded to each ``VOSTest``.
        rgb: Forwarded to each ``VOSTest``.
        result_root: Directory that receives a copy of each sequence's
            first label. When None, nothing is copied.

    Raises:
        FileNotFoundError: If the sparse ``meta.json`` does not exist.
    """

    def __init__(self,
                 root='./datasets/YTB',
                 year=2018,
                 split='val',
                 transform=None,
                 rgb=True,
                 result_root=None):
        if split == 'val':
            split = 'valid'
        root_sparse = os.path.join(root, str(year), split)
        root_dense = root_sparse + '_all_frames'
        self.db_root_dir = root_dense
        self.result_root = result_root
        self.rgb = rgb
        self.transform = transform
        self.seq_list_file = os.path.join(root_sparse, 'meta.json')
        if not self._check_preprocess():
            # Fail here with a clear message rather than later with an
            # AttributeError on the missing ``ann_f``.
            raise FileNotFoundError(
                'Sequence list file not found: {}'.format(
                    self.seq_list_file))
        self.seqs = list(self.ann_f.keys())
        self.image_root = os.path.join(root_dense, 'JPEGImages')
        self.label_root = os.path.join(root_sparse, 'Annotations')

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        seq_name = self.seqs[idx]
        data = self.ann_f[seq_name]['objects']

        images_sparse = []
        for obj_info in data.values():
            images_sparse.extend(
                frame + '.jpg' for frame in obj_info["frames"])
        images_sparse = np.sort(np.unique(images_sparse))

        images = np.sort(
            list(os.listdir(os.path.join(self.image_root, seq_name))))

        # Trim the dense frames to the span covered by the sparse frames.
        # When no frame matches, the fallbacks reproduce the values the
        # former index loops ended on (last index / index 0), and they keep
        # both indices defined when the frame folder is empty.
        start_img = images_sparse[0]
        end_img = images_sparse[-1]
        last = len(images) - 1
        start_idx = next(
            (i for i, name in enumerate(images) if start_img in name), last)
        end_idx = next(
            (i for i in range(last, -1, -1) if end_img in images[i]), 0)
        images = images[start_idx:(end_idx + 1)]

        labels = np.sort(
            list(os.listdir(os.path.join(self.label_root, seq_name))))

        self._copy_first_label(seq_name, labels)

        seq_dataset = VOSTest(self.image_root,
                              self.label_root,
                              seq_name,
                              images,
                              labels,
                              transform=self.transform,
                              rgb=self.rgb)
        seq_dataset.images_sparse = images_sparse

        return seq_dataset

    def _copy_first_label(self, seq_name, labels):
        """Best-effort copy of the first label into the result folder."""
        if self.result_root is None:
            return
        try:
            seq_result_dir = os.path.join(self.result_root, seq_name)
            target = os.path.join(seq_result_dir, labels[0])
            if not os.path.isfile(target):
                # exist_ok avoids a failure when another worker creates the
                # folder between the check and the call.
                os.makedirs(seq_result_dir, exist_ok=True)
                shutil.copy(
                    os.path.join(self.label_root, seq_name, labels[0]),
                    target)
        except Exception as inst:
            # Deliberately broad: preparing the result folder must never
            # stop evaluation of the sequence.
            print(inst)
            print('Failed to create a result folder for sequence {}.'.format(
                seq_name))

    def _check_preprocess(self):
        """Load ``meta.json`` into ``self.ann_f``.

        Returns:
            True when the file was loaded; False (after printing the path)
            when it does not exist.
        """
        _seq_list_file = self.seq_list_file
        if not os.path.isfile(_seq_list_file):
            print(_seq_list_file)
            return False
        with open(_seq_list_file, 'r') as meta_file:
            self.ann_f = json.load(meta_file)['videos']
        return True
class DAVIS_Test(object):
    """Collection of DAVIS evaluation sequences.

    Sequence names are read from ``<root>/ImageSets/<year>/<split>.txt``.
    Indexing returns a ``VOSTest`` dataset for one sequence, using the label
    that matches the sequence's first frame.

    Args:
        split: Split name or iterable of split names; ``'test'`` is read
            from ``test-dev.txt``.
        root: DAVIS root directory.
        year: Dataset year; 2016 switches on single-object mode.
        transform: Optional callable forwarded to each ``VOSTest``.
        rgb: Forwarded to each ``VOSTest``.
        full_resolution: Use the ``Full-Resolution`` folders instead of
            ``480p``.
        result_root: Directory that receives each sequence's first label.
            When None, nothing is written.
    """

    def __init__(self,
                 split=('val',),
                 root='./DAVIS',
                 year=2017,
                 transform=None,
                 rgb=True,
                 full_resolution=False,
                 result_root=None):
        self.transform = transform
        self.rgb = rgb
        self.result_root = result_root
        self.single_obj = (year == 2016)

        resolution = 'Full-Resolution' if full_resolution else '480p'
        self.image_root = os.path.join(root, 'JPEGImages', resolution)
        self.label_root = os.path.join(root, 'Annotations', resolution)

        # A bare string would otherwise be iterated character by character.
        if isinstance(split, str):
            split = [split]

        seq_names = []
        for spt in split:
            if spt == 'test':
                spt = 'test-dev'
            list_path = os.path.join(root, 'ImageSets', str(year),
                                     spt + '.txt')
            with open(list_path) as f:
                # Skip blank lines so they do not become an empty-named
                # sequence.
                seq_names.extend(
                    line.strip() for line in f if line.strip())
        self.seqs = list(np.unique(seq_names))

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        seq_name = self.seqs[idx]
        images = list(
            np.sort(os.listdir(os.path.join(self.image_root, seq_name))))
        # Swap only the extension; a plain replace('jpg', 'png') would also
        # rewrite "jpg" occurring inside the frame name.
        labels = [os.path.splitext(images[0])[0] + '.png']

        if self.result_root is not None:
            self._prepare_first_label(seq_name, labels[0])

        seq_dataset = VOSTest(self.image_root,
                              self.label_root,
                              seq_name,
                              images,
                              labels,
                              transform=self.transform,
                              rgb=self.rgb,
                              single_obj=self.single_obj,
                              resolution=480)
        return seq_dataset

    def _prepare_first_label(self, seq_name, label_name):
        """Write the first-frame label into the result folder if missing."""
        seq_result_folder = os.path.join(self.result_root, seq_name)
        result_label_path = os.path.join(seq_result_folder, label_name)
        if os.path.isfile(result_label_path):
            return

        try:
            # exist_ok avoids a failure when another worker creates the
            # folder between a check and the call.
            os.makedirs(seq_result_folder, exist_ok=True)
        except OSError as inst:
            print(inst)
            print('Failed to create a result folder for sequence {}.'.format(
                seq_name))

        source_label_path = os.path.join(self.label_root, seq_name,
                                         label_name)
        if self.single_obj:
            # Merge all objects into one foreground id and re-save the label
            # as a palettised image.
            with Image.open(source_label_path) as source_label:
                label = np.array(source_label, dtype=np.uint8)
            label = (label > 0).astype(np.uint8)
            label = Image.fromarray(label).convert('P')
            label.putpalette(_palette)
            label.save(result_label_path)
        else:
            shutil.copy(source_label_path, result_label_path)
class _EVAL_TEST(Dataset):
    """Synthetic ten-frame sequence used to exercise the evaluation pipeline.

    Every frame is an all-zero 400x400x3 float32 image. Only frame 0 carries
    a label: a 400x400 uint8 array filled with the object count (2).

    Args:
        transform: Optional callable applied to each sample dict.
        seq_name: Sequence name reported in each sample's ``meta``.
    """

    def __init__(self, transform, seq_name):
        self.transform = transform
        self.seq_name = seq_name
        self.num_frame = 10

    def __len__(self):
        return self.num_frame

    def __getitem__(self, idx):
        frame_h, frame_w = 400, 400
        num_objects = 2

        sample = {
            'current_img': np.zeros((frame_h, frame_w, 3), dtype=np.float32)
        }
        if idx == 0:
            # Only the first frame provides a reference label.
            sample['current_label'] = np.full((frame_h, frame_w),
                                              num_objects,
                                              dtype=np.uint8)

        sample['meta'] = {
            'seq_name': self.seq_name,
            'frame_num': self.num_frame,
            'obj_num': num_objects,
            'current_name': 'test{}.jpg'.format(idx),
            'height': frame_h,
            'width': frame_w,
            'flip': False
        }

        if self.transform is None:
            return sample
        return self.transform(sample)
class EVAL_TEST(object):
    """Collection of three synthetic sequences for pipeline smoke tests.

    Indexing returns an ``_EVAL_TEST`` dataset for the selected sequence and
    makes sure its result folder exists.

    Args:
        transform: Optional callable forwarded to each ``_EVAL_TEST``.
        result_root: Directory under which one folder per sequence is
            created. When None, no folder is created.
    """

    def __init__(self, transform=None, result_root=None):
        self.transform = transform
        self.result_root = result_root

        self.seqs = ['test1', 'test2', 'test3']

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        seq_name = self.seqs[idx]

        if self.result_root is not None:
            # exist_ok avoids a failure when another worker creates the
            # folder between a check and the call.
            os.makedirs(os.path.join(self.result_root, seq_name),
                        exist_ok=True)

        seq_dataset = _EVAL_TEST(self.transform, seq_name)
        return seq_dataset