import os

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader
class CommonDataset(Dataset):
    """Segmentation-style dataset pairing input images with label masks.

    Parameters
    ----------
    images_path : list[str]
        File paths of the input images, one per sample.
    labels_path : list[str]
        File paths of the label masks, parallel to ``images_path``.
    x_transform : callable
        Transform applied to each loaded input image (a PIL RGB image).
    y_transform : callable
        Transform applied to each loaded label mask (a PIL 'L' image).

    Raises
    ------
    ValueError
        If ``images_path`` and ``labels_path`` have different lengths.
    """

    def __init__(self, images_path, labels_path, x_transform, y_transform):
        # Fail fast on mismatched lists instead of an IndexError later
        # inside a DataLoader worker, where it is much harder to debug.
        if len(images_path) != len(labels_path):
            raise ValueError(
                'images_path and labels_path must have the same length: '
                f'{len(images_path)} != {len(labels_path)}'
            )
        self.imgs_path = images_path
        self.labels_path = labels_path
        self.x_transform = x_transform
        self.y_transform = y_transform

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.imgs_path)

    def __getitem__(self, idx):
        """Load, transform, and return the ``(image, label)`` pair at *idx*."""
        # default_loader yields a PIL image converted to RGB.
        x = default_loader(self.imgs_path[idx])
        # Labels are loaded single-channel ('L') rather than through
        # default_loader: an RGB conversion would corrupt class-index masks.
        # NOTE(review): assumes masks store class ids in one 8-bit channel —
        # confirm against how labels_path files were generated.
        y = Image.open(self.labels_path[idx]).convert('L')
        return self.x_transform(x), self.y_transform(y)