import ast
import json
import os
import pickle
import random
import zipfile
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from PIL import Image, ImageFile
from torchvision import datasets as t_datasets
from torchvision import transforms

from constants import CHEXPERT_COMPETITION_TASKS
|
|
|
|
ImageFile.LOAD_TRUNCATED_IMAGES = True |
|
|
|
|
|
def pil_loader(path): |
|
|
|
with open(path, 'rb') as f: |
|
img = Image.open(f) |
|
return img.convert('RGB') |
|
|
|
|
|
class FileListDataset(torch.utils.data.Dataset): |
|
def __init__(self, images, labels, transform=None, target_transform=None): |
|
self.transform = transform |
|
self.target_transform = target_transform |
|
self.images = np.load(images) |
|
self.labels = np.load(labels) |
|
|
|
def __getitem__(self, index): |
|
img = pil_loader(self.images[index]) |
|
target = self.labels[index] |
|
|
|
if self.transform is not None: |
|
img = self.transform(img) |
|
|
|
if self.target_transform is not None: |
|
target = self.target_transform(target) |
|
|
|
return img, target |
|
|
|
def __len__(self): |
|
return len(self.images) |
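
# A minimal usage sketch for FileListDataset (paths are hypothetical): the two .npy files
# are assumed to hold, respectively, an array of image file paths and an array of labels
# saved with np.save, which is what the np.load calls above expect.
#
#     train_set = FileListDataset('/data/clevr/train_images.npy',
#                                 '/data/clevr/train_labels.npy',
#                                 transform=transforms.ToTensor())
#     img, target = train_set[0]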
|
|
|
|
|
def get_downstream_dataset(catalog, name, is_train, transform): |
|
entry = catalog[name] |
|
root = entry['path'] |
|
if entry['type'] == 'imagefolder': |
|
dataset = t_datasets.ImageFolder(os.path.join(root, entry['train'] if is_train else entry['test']), |
|
transform=transform) |
|
elif entry['type'] == 'special': |
|
if name == 'CIFAR10': |
|
dataset = t_datasets.CIFAR10(root, train=is_train, |
|
transform=transform, download=True) |
|
elif name == 'CIFAR100': |
|
dataset = t_datasets.CIFAR100(root, train=is_train, |
|
transform=transform, download=True) |
|
elif name == 'STL10': |
|
dataset = t_datasets.STL10(root, split='train' if is_train else 'test', |
|
transform=transform, download=True) |
|
elif name == 'MNIST': |
|
dataset = t_datasets.MNIST(root, train=is_train, |
|
transform=transform, download=True) |
|
elif name == 'chexpert-5x200': |
|
            dataset = ZeroShotImageDataset(['chexpert-5x200'], CHEXPERT_COMPETITION_TASKS,
                                           transform, parent_data_path=root)
|
elif name == "radimagenet": |
|
dataset = RadImageNet(root, transform) |
|
elif name == "rsna_pneumonia": |
|
dataset = RSNA_Pneumonia(root, transform) |
|
elif name == "thyroid_us": |
|
dataset = thyroid_us_and_breast(root, transform, "thyroid_test_fold1.csv") |
|
elif name == "breast_us": |
|
dataset = thyroid_us_and_breast(root, transform, "breast_test_fold1.csv") |
|
elif name == "meniscal_mri": |
|
dataset = meniscal_mri(root, transform, "meniscus_test_fold1.csv") |
|
elif name == 'acl_mri': |
|
dataset = acl_mri(root, transform, "test_fold1.csv") |
|
elif name == 'CT_axial': |
|
dataset = CT_dataset(root, transform, "organs_axial") |
|
elif name == 'CT_coronal': |
|
dataset = CT_dataset(root, transform, "organs_coronal") |
|
elif name == 'CT_sagittal': |
|
dataset = CT_dataset(root, transform, "organs_sagittal") |
|
elif name == 'dr_regular': |
|
dataset = CT_dataset(root, transform, "dr_regular") |
|
elif name == 'dr_uwf': |
|
dataset = CT_dataset(root, transform, "dr_uwf") |
|
elif name == 'LC25000_lung': |
|
dataset = LC25000(root, transform, "lung") |
|
elif name == 'LC25000_colon': |
|
dataset = LC25000(root, transform, "colon") |
|
elif name == 'PCAM': |
|
dataset = PCAM(root, transform, "PCam_Test_preprocessed") |
|
elif name == 'NCK_CRC': |
|
dataset = NCK_CRC(root, transform, "CRC-VAL-HE-7K") |
|
elif name == 'BACH': |
|
dataset = BACH(root, transform, "BACH") |
|
elif name == 'Osteo': |
|
dataset = Osteo(root, transform, "Osteosarcoma") |
|
elif name == 'skin_cancer': |
|
dataset = Skin_datasets(root, transform, "skin_tumor", 'cancer') |
|
elif name == 'skin_tumor': |
|
dataset = Skin_datasets(root, transform, "skin_tumor", 'tumor') |
|
elif name == 'refuge_retina': |
|
dataset = Retina_datasets(root, transform, '25_REFUGE.csv') |
|
elif name == 'five_retina': |
|
dataset = Retina_datasets(root, transform, '13_FIVES.csv') |
|
        elif name == 'odir_retina':
            dataset = Retina_datasets(root, transform, '08_ODIR200x3.csv')
        else:
            raise ValueError(f'Unknown special dataset: {name}')
|
|
|
elif entry['type'] == 'filelist': |
|
path = entry['train'] if is_train else entry['test'] |
|
val_images = os.path.join(root, path + '_images.npy') |
|
val_labels = os.path.join(root, path + '_labels.npy') |
|
        if name == 'CLEVRCounts':
            clevr_classes = ['count_10', 'count_3', 'count_4', 'count_5', 'count_6',
                             'count_7', 'count_8', 'count_9']
            target_transform = clevr_classes.index
|
else: |
|
target_transform = None |
|
dataset = FileListDataset(val_images, val_labels, transform, target_transform) |
|
    else:
        raise ValueError(f"Unknown dataset type: {entry['type']}")
|
|
|
return dataset |
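
# A minimal usage sketch for get_downstream_dataset (the catalog layout is an assumption
# inferred from how entry['type'], entry['path'], entry['train'], and entry['test'] are
# read above; dataset names other than the built-in ones and all paths are hypothetical):
#
#     catalog = {
#         'CIFAR10': {'type': 'special', 'path': '/data/cifar10'},
#         'Food101': {'type': 'imagefolder', 'path': '/data/food101',
#                     'train': 'train', 'test': 'test'},
#     }
#     transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224),
#                                     transforms.ToTensor()])
#     test_set = get_downstream_dataset(catalog, 'CIFAR10', is_train=False, transform=transform)
#     loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)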
|
|
|
|
|
class ZeroShotImageDataset(torch.utils.data.Dataset): |
|
def __init__(self, |
|
datalist=['chexpert-5x200'], |
|
class_names=None, |
|
imgtransform=None, |
|
parent_data_path="", |
|
) -> None: |
|
        '''Zero-shot evaluation dataset; supported data lists include mimic-5x200,
        chexpert-5x200, rsna-balanced-test, and covid-test.
        args:
            datalist: list of dataset names; each loads ./local_data/<name>.csv
            class_names: CSV columns used to build the label vector
            imgtransform: a torchvision transform applied to each image
            parent_data_path: root directory prepended to the image paths in the CSV
        '''
|
super().__init__() |
|
|
|
self.transform = imgtransform |
|
|
|
self.class_names = class_names |
|
self.parent_data_path = parent_data_path |
|
|
|
df_list = [] |
|
for data in datalist: |
|
filename = f'./local_data/{data}.csv' |
|
print('load data from', filename) |
|
df = pd.read_csv(filename, index_col=0) |
|
df_list.append(df) |
|
self.df = pd.concat(df_list, axis=0).reset_index(drop=True) |
|
|
|
def __getitem__(self, index): |
|
row = self.df.iloc[index] |
|
img = Image.open(os.path.join(self.parent_data_path, row.imgpath)) |
|
|
|
        if self.transform is not None:
            img = self.transform(img)
        label = torch.from_numpy(row[self.class_names].values.astype(np.float64))
|
return img, label |
|
|
|
def _pad_img(self, img, min_size=224, fill_color=0): |
|
'''pad img to square. |
|
''' |
|
x, y = img.size |
|
size = max(min_size, x, y) |
|
new_im = Image.new('L', (size, size), fill_color) |
|
new_im.paste(img, (int((size - x) / 2), int((size - y) / 2))) |
|
return new_im |
|
|
|
def __len__(self): |
|
return len(self.df) |
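
# A minimal usage sketch for ZeroShotImageDataset (a sketch; assumes ./local_data/chexpert-5x200.csv
# exists with an 'imgpath' column plus one column per entry in CHEXPERT_COMPETITION_TASKS, as read
# in __getitem__ above; the data path is hypothetical):
#
#     dataset = ZeroShotImageDataset(['chexpert-5x200'], CHEXPERT_COMPETITION_TASKS,
#                                    imgtransform=transforms.ToTensor(),
#                                    parent_data_path='/data/chexpert')
#     img, label = dataset[0]   # label is a float tensor over the competition tasks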
|
|
|
|
|
class RadImageNet(torch.utils.data.Dataset): |
|
def __init__(self, parent_path, transform=None): |
|
""" |
|
        Args:
            parent_path (str): Root directory containing radimagenet_test_set_formatted.csv
                and the image files referenced in it.
            transform (callable, optional): Optional transform applied to each image.
|
""" |
|
self.data = pd.read_csv(os.path.join(parent_path, "radimagenet_test_set_formatted.csv")) |
|
self.transform = transform |
|
self.parent_data_path = parent_path |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
def __getitem__(self, idx): |
|
if torch.is_tensor(idx): |
|
idx = idx.tolist() |
|
|
|
img_name = self.data.iloc[idx, 0] |
|
img_name = ast.literal_eval(img_name)[0] |
|
img = Image.open(os.path.join(self.parent_data_path, img_name)) |
|
label = self.data.iloc[idx, 1] |
|
|
|
        if self.transform:
            img = self.transform(img)

        return img, label
|
|
|
|
|
class CT_dataset(torch.utils.data.Dataset): |
|
def __init__(self, parent_path, transform=None, foldername=None): |
|
""" |
|
        Args:
            parent_path (str): Root data directory.
            transform (callable, optional): Optional transform applied to each image.
            foldername (str): Subfolder containing annotations.csv and the images;
                only rows with split == 'test' are used.
|
""" |
|
all_data = pd.read_csv(os.path.join(os.path.join(parent_path, foldername), |
|
"annotations.csv")) |
|
|
|
self.data = all_data[all_data['split'] == 'test'] |
|
self.transform = transform |
|
self.parent_data_path = os.path.join(parent_path, foldername) |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
def __getitem__(self, idx): |
|
if torch.is_tensor(idx): |
|
idx = idx.tolist() |
|
|
|
img_name = self.data.iloc[idx, 0] |
|
img = Image.open(os.path.join(self.parent_data_path, img_name)) |
|
label = self.data.iloc[idx, 2] |
|
|
|
        if self.transform:
            img = self.transform(img)

        return img, label
|
|
|
|
|
class LC25000(torch.utils.data.Dataset): |
|
def __init__(self, parent_path, transform=None, split=None): |
|
""" |
|
        Args:
            parent_path (str): Root directory containing one subfolder per class.
            transform (callable, optional): Optional transform applied to each image.
            split (str): 'lung' or 'colon'; selects which class folders are loaded.
|
""" |
|
if split == "lung": |
|
classes = ['lung_aca', 'lung_n', 'lung_scc'] |
|
else: |
|
classes = ['colon_aca', 'colon_n'] |
|
self.split = split |
|
self.class_name_folder = [] |
|
self.images = [] |
|
self.labels = [] |
|
for idx, single_class_folder in enumerate(classes): |
|
images_per_classes = list(os.listdir(os.path.join(parent_path, single_class_folder))) |
|
self.images = self.images + images_per_classes |
|
self.labels = self.labels + ([idx] * len(images_per_classes)) |
|
self.class_name_folder = self.class_name_folder + ([single_class_folder] * len(images_per_classes)) |
|
self.transform = transform |
|
self.parent_data_path = parent_path |
|
assert len(self.images) == len(self.labels) == len(self.class_name_folder) |
|
|
|
def __len__(self): |
|
return len(self.images) |
|
|
|
def __getitem__(self, idx): |
|
if torch.is_tensor(idx): |
|
idx = idx.tolist() |
|
|
|
img_name = self.images[idx] |
|
class_folder_name = self.class_name_folder[idx] |
|
img = Image.open(os.path.join(os.path.join(self.parent_data_path, class_folder_name), img_name)) |
|
label = self.labels[idx] |
|
|
|
        if self.transform:
            img = self.transform(img)

        return img, label
|
|
|
|
|
class PCAM(torch.utils.data.Dataset): |
|
def __init__(self, parent_path, transform=None, foldername=None): |
|
""" |
|
        Args:
            parent_path (str): Root data directory.
            transform (callable, optional): Optional transform applied to each image.
            foldername (str): Subfolder of image files whose label is parsed from the
                token after the first underscore in the filename.
|
""" |
|
all_files = os.listdir(os.path.join(parent_path, foldername)) |
|
|
|
self.data = all_files |
|
|
|
labels = [] |
|
for single_file in all_files: |
|
splitted_label = int(single_file.split("_")[1].split(".")[0]) |
|
labels.append(splitted_label) |
|
self.labels = labels |
|
self.transform = transform |
|
self.parent_data_path = os.path.join(parent_path, foldername) |
|
assert len(self.labels) == len(self.data) |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
def __getitem__(self, idx): |
|
if torch.is_tensor(idx): |
|
idx = idx.tolist() |
|
|
|
img_name = self.data[idx] |
|
img = Image.open(os.path.join(self.parent_data_path, img_name)) |
|
label = self.labels[idx] |
|
|
|
        if self.transform:
            img = self.transform(img)

        return img, label
|
|
|
|
|
class NCK_CRC(torch.utils.data.Dataset): |
|
def __init__(self, parent_path, transform=None, foldername=None): |
|
""" |
|
        Args:
            parent_path (str): Root data directory.
            transform (callable, optional): Optional transform applied to each image.
            foldername (str): Subfolder containing one directory per tissue class
                (ADI, DEB, LYM, MUC, MUS, NORM, STR, TUM).
|
""" |
|
NCK_CRC_converter = {"ADI": 0, |
|
"DEB": 1, |
|
"LYM": 2, |
|
"MUC": 3, |
|
"MUS": 4, |
|
"NORM": 5, |
|
"STR": 6, |
|
"TUM": 7, |
|
} |
|
all_data = [] |
|
all_class_names = [] |
|
all_labels = [] |
|
folder_names = os.listdir(os.path.join(parent_path, foldername)) |
|
for single_folder in folder_names: |
|
class_path = os.path.join(os.path.join(parent_path, foldername), single_folder) |
|
images_inside_folder = os.listdir(class_path) |
|
class_label = [NCK_CRC_converter[single_folder]] * len(images_inside_folder) |
|
all_data.extend(images_inside_folder) |
|
all_labels.extend(class_label) |
|
all_class_names.extend([single_folder] * len(images_inside_folder)) |
|
|
|
self.data = all_data |
|
self.labels = all_labels |
|
self.prefix_name = all_class_names |
|
assert len(self.data) == len(self.labels) == len(self.prefix_name) |
|
self.transform = transform |
|
self.parent_data_path = os.path.join(parent_path, foldername) |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
def __getitem__(self, idx): |
|
if torch.is_tensor(idx): |
|
idx = idx.tolist() |
|
|
|
img_name = self.data[idx] |
|
class_name = self.prefix_name[idx] |
|
label = self.labels[idx] |
|
img = Image.open(os.path.join(self.parent_data_path, os.path.join(class_name, img_name))) |
|
        if self.transform:
            img = self.transform(img)

        return img, label
|
|
|
|
|
class BACH(torch.utils.data.Dataset): |
|
def __init__(self, parent_path, transform=None, foldername=None): |
|
""" |
|
        Args:
            parent_path (str): Root data directory.
            transform (callable, optional): Optional transform applied to each image.
            foldername (str): Subfolder containing microscopy_ground_truth.csv and the
                per-class image directories.
|
""" |
|
all_data = pd.read_csv(os.path.join(os.path.join(parent_path, foldername), |
|
"microscopy_ground_truth.csv")) |
|
|
|
self.data = all_data |
|
self.transform = transform |
|
self.parent_data_path = os.path.join(parent_path, foldername) |
|
self.label_to_text_mapping = {'Normal': 3, 'Invasive': 2, 'InSitu': 1, "Benign": 0} |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
def __getitem__(self, idx): |
|
if torch.is_tensor(idx): |
|
idx = idx.tolist() |
|
|
|
img_name = self.data.iloc[idx, 0] |
|
label_text = self.data.iloc[idx, 1] |
|
        img = Image.open(os.path.join(self.parent_data_path, label_text, img_name))
|
|
|
label = self.label_to_text_mapping[label_text] |
|
|
|
        if self.transform:
            img = self.transform(img)

        return img, label
|
|
|
|
|
class Retina_datasets(torch.utils.data.Dataset): |
|
def __init__(self, parent_path, transform=None, data=None): |
|
""" |
|
        Args:
            parent_path (str): Root directory containing the image files.
            transform (callable, optional): Optional transform applied to each image.
            data (str): CSV filename under ./local_data with image paths and labels.
|
""" |
|
filename = f'./local_data/{data}' |
|
all_data = pd.read_csv(filename) |
|
|
|
self.data = all_data |
|
self.transform = transform |
|
self.parent_data_path = parent_path |
|
|
|
if data == '25_REFUGE.csv': |
|
self.label_to_text_mapping = {'no glaucoma': 0, 'glaucoma': 1} |
|
elif data == '13_FIVES.csv': |
|
self.label_to_text_mapping = {"age related macular degeneration": 0, |
|
"diabetic retinopathy": 1, |
|
"glaucoma": 2, |
|
"normal": 3} |
|
elif data == '08_ODIR200x3.csv': |
|
self.label_to_text_mapping = {"normal": 0, |
|
"pathologic myopia": 1, |
|
"cataract": 2} |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
def __getitem__(self, idx): |
|
if torch.is_tensor(idx): |
|
idx = idx.tolist() |
|
|
|
img_name = self.data.iloc[idx, 1] |
|
label_text = self.data.iloc[idx, 3] |
|
label_text = ast.literal_eval(label_text)[0] |
|
img = Image.open(os.path.join(self.parent_data_path, img_name)) |
|
|
|
label = self.label_to_text_mapping[label_text] |
|
|
|
        if self.transform:
            img = self.transform(img)

        return img, label
|
|
|
class Skin_datasets(torch.utils.data.Dataset): |
|
def __init__(self, parent_path, transform=None, foldername=None, split_type='cancer'): |
|
""" |
|
        Args:
            parent_path (str): Root data directory.
            transform (callable, optional): Optional transform applied to each image.
            foldername (str): Subfolder containing the data/ CSV files and images.
            split_type (str): 'cancer' for the full tile set; anything else selects the
                skin-tumor subset.
|
""" |
|
if split_type == 'cancer': |
|
all_data = pd.read_csv(os.path.join(os.path.join(parent_path, foldername), |
|
"data/tiles-v2.csv")) |
|
|
|
all_data = all_data[all_data['set'] == 'Test'] |
|
            self.label_to_text_mapping = {
                "nontumor_skin_necrosis_necrosis": 0,
                "nontumor_skin_muscle_skeletal": 1,
                "nontumor_skin_sweatglands_sweatglands": 2,
                "nontumor_skin_vessel_vessel": 3,
                "nontumor_skin_elastosis_elastosis": 4,
                "nontumor_skin_chondraltissue_chondraltissue": 5,
                "nontumor_skin_hairfollicle_hairfollicle": 6,
                "nontumor_skin_epidermis_epidermis": 7,
                "nontumor_skin_nerves_nerves": 8,
                "nontumor_skin_subcutis_subcutis": 9,
                "nontumor_skin_dermis_dermis": 10,
                "nontumor_skin_sebaceousglands_sebaceousglands": 11,
                "tumor_skin_epithelial_sqcc": 12,
                "tumor_skin_melanoma_melanoma": 13,
                "tumor_skin_epithelial_bcc": 14,
                "tumor_skin_naevus_naevus": 15,
            }
|
else: |
|
all_data = pd.read_csv(os.path.join(os.path.join(parent_path, foldername), |
|
"data/SkinTumorSubset.csv")) |
|
|
|
all_data = all_data[all_data['set'] == 'Test'] |
|
            self.label_to_text_mapping = {
                "tumor_skin_epithelial_sqcc": 0,
                "tumor_skin_melanoma_melanoma": 1,
                "tumor_skin_epithelial_bcc": 2,
                "tumor_skin_naevus_naevus": 3,
            }
|
|
|
self.data = all_data |
|
self.transform = transform |
|
self.parent_data_path = os.path.join(parent_path, foldername) |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
def __getitem__(self, idx): |
|
if torch.is_tensor(idx): |
|
idx = idx.tolist() |
|
|
|
img_name = self.data.iloc[idx, 1] |
|
label_text = self.data.iloc[idx, 2] |
|
img = Image.open(os.path.join(self.parent_data_path, img_name)) |
|
|
|
label = self.label_to_text_mapping[label_text] |
|
|
|
        if self.transform:
            img = self.transform(img)

        return img, label
|
|
|
|
|
class Osteo(torch.utils.data.Dataset): |
|
def __init__(self, parent_path, transform=None, foldername=None): |
|
""" |
|
        Args:
            parent_path (str): Root data directory.
            transform (callable, optional): Optional transform applied to each image.
            foldername (str): Subfolder containing annotations_final.csv and the images.
|
""" |
|
|
|
all_data = pd.read_csv(os.path.join(os.path.join(parent_path, foldername), |
|
"annotations_final.csv")) |
|
|
|
self.data = all_data |
|
self.transform = transform |
|
self.parent_data_path = os.path.join(parent_path, foldername) |
|
self.label_to_text_mapping = {'Viable': 2, 'Non-Tumor': 0, "Non-Viable-Tumor": 1} |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
def __getitem__(self, idx): |
|
if torch.is_tensor(idx): |
|
idx = idx.tolist() |
|
|
|
img_name = self.data.iloc[idx, 0] |
|
label_text = self.data.iloc[idx, 1] |
|
img = Image.open(os.path.join(self.parent_data_path, img_name)) |
|
|
|
label = self.label_to_text_mapping[label_text] |
|
|
|
        if self.transform:
            img = self.transform(img)

        return img, label
|
|
|
|
|
class RSNA_Pneumonia(torch.utils.data.Dataset): |
|
def __init__(self, parent_path, transform=None): |
|
""" |
|
        Args:
            parent_path (str): Root directory containing RSNA_pneumonia_balanced_testfile.csv
                and the image files referenced in it.
            transform (callable, optional): Optional transform applied to each image.
|
""" |
|
self.data = pd.read_csv(os.path.join(parent_path, "RSNA_pneumonia_balanced_testfile.csv")) |
|
self.transform = transform |
|
self.parent_data_path = parent_path |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
def __getitem__(self, idx): |
|
if torch.is_tensor(idx): |
|
idx = idx.tolist() |
|
|
|
img_name = self.data.iloc[idx, 1] |
|
img = Image.open(os.path.join(self.parent_data_path, img_name)) |
|
label = self.data.iloc[idx, 2] |
|
|
|
        if self.transform:
            img = self.transform(img)

        return img, label
|
|
|
|
|
class thyroid_us_and_breast(torch.utils.data.Dataset): |
|
def __init__(self, parent_path, transform, csv_file_name): |
|
""" |
|
        Args:
            parent_path (str): Root directory containing the CSV file and the images.
            transform (callable): Transform applied to each image.
            csv_file_name (str): Name of the CSV file with image paths and labels.
|
""" |
|
self.data = pd.read_csv(os.path.join(parent_path, csv_file_name)) |
|
self.transform = transform |
|
self.parent_data_path = parent_path |
|
|
|
self.mapping = {'malignant': 1, 'benign': 0} |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
def __getitem__(self, idx): |
|
if torch.is_tensor(idx): |
|
idx = idx.tolist() |
|
|
|
img_name = self.data.iloc[idx, 0] |
|
img = Image.open(os.path.join(self.parent_data_path, img_name)) |
|
label_name = self.data.iloc[idx, 1] |
|
label = self.mapping[label_name] |
|
|
|
        if self.transform:
            img = self.transform(img)

        return img, label
|
|
|
|
|
class meniscal_mri(torch.utils.data.Dataset): |
|
def __init__(self, parent_path, transform, csv_file_name): |
|
""" |
|
        Args:
            parent_path (str): Root directory containing the CSV file and the images.
            transform (callable): Transform applied to each image.
            csv_file_name (str): Name of the CSV file with image paths and labels.
|
""" |
|
self.data = pd.read_csv(os.path.join(parent_path, csv_file_name)) |
|
self.transform = transform |
|
self.parent_data_path = parent_path |
|
self.mapping = {'p': 1, 'n': 0} |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
def __getitem__(self, idx): |
|
if torch.is_tensor(idx): |
|
idx = idx.tolist() |
|
|
|
img_name = self.data.iloc[idx, 0] |
|
img = Image.open(os.path.join(self.parent_data_path, img_name)) |
|
label_name = self.data.iloc[idx, 1] |
|
label = self.mapping[label_name] |
|
|
|
        if self.transform:
            img = self.transform(img)

        return img, label
|
|
|
|
|
class acl_mri(torch.utils.data.Dataset): |
|
def __init__(self, parent_path, transform, csv_file_name): |
|
""" |
|
        Args:
            parent_path (str): Root directory containing the CSV file and the images.
            transform (callable): Transform applied to each image.
            csv_file_name (str): Name of the CSV file with image paths and labels.
|
""" |
|
self.data = pd.read_csv(os.path.join(parent_path, csv_file_name)) |
|
self.transform = transform |
|
self.parent_data_path = parent_path |
|
self.mapping = {'yes': 1, 'no': 0} |
|
|
|
def __len__(self): |
|
return len(self.data) |
|
|
|
def __getitem__(self, idx): |
|
if torch.is_tensor(idx): |
|
idx = idx.tolist() |
|
|
|
img_name = self.data.iloc[idx, 0] |
|
img = Image.open(os.path.join(self.parent_data_path, img_name)) |
|
label_name = self.data.iloc[idx, 1] |
|
label = self.mapping[label_name] |
|
|
|
        if self.transform:
            img = self.transform(img)

        return img, label
|
|