|
import torch |
|
import os.path as osp |
|
import glob |
|
import cv2 |
|
import numpy as np |
|
from torch.utils import data |
|
|
|
|
|
class PathImages(data.Dataset):
    """Dataset over the image files in a single directory, in sorted order.

    Each item is the image at that position, loaded with OpenCV, converted
    from BGR to RGB, and returned as a ``float32`` tensor of shape
    ``(3, H, W)`` holding the raw 0..255 pixel values (no normalization).
    """

    def __init__(self, root_path, pattern='*.jpg'):
        """Collect the files in *root_path* whose names match *pattern*.

        Args:
            root_path: Directory to scan (non-recursive). A missing or empty
                directory simply yields an empty dataset.
            pattern: ``glob`` pattern for file names, relative to
                *root_path*. Defaults to ``'*.jpg'``, the only pattern the
                original version supported. Glob matching is case-sensitive
                on POSIX filesystems, so e.g. ``'*.JPG'`` files are not
                matched by the default there.
        """
        # Sort so the index -> file mapping is deterministic; glob's own
        # ordering depends on the filesystem.
        self.images_files = sorted(glob.glob(osp.join(root_path, pattern)))

    def __len__(self):
        """Return the number of image files found at construction time."""
        return len(self.images_files)

    def __getitem__(self, index):
        """Load and return the image at *index* via :meth:`image2tensor`."""
        return self.image2tensor(self.images_files[index])

    @staticmethod
    def image2tensor(image_file):
        """Load *image_file* as an RGB ``float32`` tensor of shape ``(3, H, W)``.

        Args:
            image_file: Path of the image to read.

        Returns:
            A contiguous ``torch.float32`` tensor in channel-first (C, H, W)
            layout, channels ordered R, G, B, values in 0..255.

        Raises:
            OSError: if OpenCV cannot read the file (missing, unreadable,
                or an unsupported/corrupt image).
        """
        img = cv2.imread(image_file, cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None rather than raising;
        # without this check the error surfaces later as an opaque cv2.error
        # from cvtColor that does not mention which file was bad.
        if img is None:
            raise OSError(f"Could not read image file: {image_file!r}")
        # IMREAD_COLOR always yields an 8-bit, 3-channel BGR array, so no
        # dtype conversion is needed before reordering channels to RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # HWC -> CHW, made contiguous so torch.from_numpy can wrap it
        # without an extra copy; .float() then produces the final tensor.
        chw = np.ascontiguousarray(img.transpose(2, 0, 1))
        return torch.from_numpy(chw).float()
|
|