muzairkhattak
first commit for the demo
37b3db0
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
import json
import os
import pickle
import zipfile
import pandas as pd
import numpy as np
from PIL import Image, ImageFile
import ast
import torch
import random
from constants import CHEXPERT_COMPETITION_TASKS
from torchvision import transforms
from torchvision import datasets as t_datasets
import ast
ImageFile.LOAD_TRUNCATED_IMAGES = True
def pil_loader(path):
    """Read the image file at ``path`` and return it converted to RGB.

    The file object is opened explicitly and closed by the ``with`` block,
    avoiding the ResourceWarning described in
    https://github.com/python-pillow/Pillow/issues/835.
    """
    with open(path, 'rb') as handle:
        rgb_image = Image.open(handle).convert('RGB')
    return rgb_image
class FileListDataset(torch.utils.data.Dataset):
    """Dataset described by two ``.npy`` files: image paths and labels.

    Args:
        images: path of a ``.npy`` file whose entries are image file paths.
        labels: path of a ``.npy`` file with one label per image entry.
        transform: optional callable applied to each loaded RGB image.
        target_transform: optional callable applied to each label.
    """

    def __init__(self, images, labels, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform
        self.images, self.labels = np.load(images), np.load(labels)

    def __getitem__(self, index):
        sample = pil_loader(self.images[index])
        label = self.labels[index]
        sample = sample if self.transform is None else self.transform(sample)
        label = label if self.target_transform is None else self.target_transform(label)
        return sample, label

    def __len__(self):
        return len(self.images)
def get_downstream_dataset(catalog, name, is_train, transform):
    """Build the dataset registered under ``name`` in ``catalog``.

    Args:
        catalog: mapping from dataset name to an entry dict. Every entry has
            ``'path'`` (root directory) and ``'type'`` (``'imagefolder'``,
            ``'special'`` or ``'filelist'``). ``'imagefolder'`` and
            ``'filelist'`` entries also provide ``'train'`` and ``'test'``.
        name: key of the dataset in ``catalog``.
        is_train: select the training split for datasets that have one. The
            non-torchvision ``'special'`` datasets ignore this flag.
        transform: image transform forwarded to the dataset.

    Returns:
        A ``torch.utils.data.Dataset`` instance.

    Raises:
        KeyError: if ``name`` is not in ``catalog``.
        ValueError: if the entry type, or a ``'special'`` dataset name, is
            not recognised (previously an unknown special name caused an
            UnboundLocalError).
    """
    entry = catalog[name]
    root = entry['path']
    dataset_type = entry['type']

    if dataset_type == 'imagefolder':
        split_dir = entry['train'] if is_train else entry['test']
        return t_datasets.ImageFolder(os.path.join(root, split_dir), transform=transform)

    if dataset_type == 'special':
        # Zero-argument factories: only the requested dataset is constructed.
        special_factories = {
            'CIFAR10': lambda: t_datasets.CIFAR10(
                root, train=is_train, transform=transform, download=True),
            'CIFAR100': lambda: t_datasets.CIFAR100(
                root, train=is_train, transform=transform, download=True),
            'STL10': lambda: t_datasets.STL10(
                root, split='train' if is_train else 'test',
                transform=transform, download=True),
            'MNIST': lambda: t_datasets.MNIST(
                root, train=is_train, transform=transform, download=True),
            'chexpert-5x200': lambda: ZeroShotImageDataset(
                ['chexpert-5x200'], CHEXPERT_COMPETITION_TASKS, transform,
                parent_data_path=root),
            'radimagenet': lambda: RadImageNet(root, transform),
            'rsna_pneumonia': lambda: RSNA_Pneumonia(root, transform),
            'thyroid_us': lambda: thyroid_us_and_breast(root, transform, "thyroid_test_fold1.csv"),
            'breast_us': lambda: thyroid_us_and_breast(root, transform, "breast_test_fold1.csv"),
            'meniscal_mri': lambda: meniscal_mri(root, transform, "meniscus_test_fold1.csv"),
            'acl_mri': lambda: acl_mri(root, transform, "test_fold1.csv"),
            'CT_axial': lambda: CT_dataset(root, transform, "organs_axial"),
            'CT_coronal': lambda: CT_dataset(root, transform, "organs_coronal"),
            'CT_sagittal': lambda: CT_dataset(root, transform, "organs_sagittal"),
            'dr_regular': lambda: CT_dataset(root, transform, "dr_regular"),
            'dr_uwf': lambda: CT_dataset(root, transform, "dr_uwf"),
            'LC25000_lung': lambda: LC25000(root, transform, "lung"),
            'LC25000_colon': lambda: LC25000(root, transform, "colon"),
            'PCAM': lambda: PCAM(root, transform, "PCam_Test_preprocessed"),
            'NCK_CRC': lambda: NCK_CRC(root, transform, "CRC-VAL-HE-7K"),
            'BACH': lambda: BACH(root, transform, "BACH"),
            'Osteo': lambda: Osteo(root, transform, "Osteosarcoma"),
            'skin_cancer': lambda: Skin_datasets(root, transform, "skin_tumor", 'cancer'),
            'skin_tumor': lambda: Skin_datasets(root, transform, "skin_tumor", 'tumor'),
            'refuge_retina': lambda: Retina_datasets(root, transform, '25_REFUGE.csv'),
            'five_retina': lambda: Retina_datasets(root, transform, '13_FIVES.csv'),
            'odir_retina': lambda: Retina_datasets(root, transform, '08_ODIR200x3.csv'),
        }
        factory = special_factories.get(name)
        if factory is None:
            raise ValueError(f"Unknown 'special' dataset: {name!r}")
        return factory()

    if dataset_type == 'filelist':
        split_prefix = entry['train'] if is_train else entry['test']
        images_file = os.path.join(root, split_prefix + '_images.npy')
        labels_file = os.path.join(root, split_prefix + '_labels.npy')
        target_transform = None
        if name == 'CLEVRCounts':
            # String count labels map to their index in this fixed class order.
            clevr_count_classes = ['count_10', 'count_3', 'count_4', 'count_5',
                                   'count_6', 'count_7', 'count_8', 'count_9']
            target_transform = clevr_count_classes.index
        return FileListDataset(images_file, labels_file, transform, target_transform)

    raise ValueError(f"Unknown dataset type {dataset_type!r} for {name!r}")
class ZeroShotImageDataset(torch.utils.data.Dataset):
    """Multi-label image dataset listed in ``./local_data/<name>.csv`` files.

    Each CSV is read with its first column as the index, and the CSVs are
    concatenated. Rows must provide an ``imgpath`` column (relative to
    ``parent_data_path``) plus one column per entry of ``class_names``.
    """

    def __init__(self,
                 datalist=('chexpert-5x200',),
                 class_names=None,
                 imgtransform=None,
                 parent_data_path="",
                 ) -> None:
        '''Load and concatenate the listed CSV files.

        The original notes list mimic-5x200, chexpert-5x200,
        rsna-balanced-test and covid-test as supported names.

        Args:
            datalist: CSV base names (without ``.csv``) under ``./local_data``.
            class_names: label column names returned, in order, by
                ``__getitem__``. NOTE(review): must be set before indexing.
            imgtransform: optional torchvision-style transform for images.
            parent_data_path: directory that ``imgpath`` entries are relative to.
        '''
        super().__init__()
        self.transform = imgtransform
        self.class_names = class_names
        self.parent_data_path = parent_data_path
        df_list = []
        for data in datalist:
            filename = f'./local_data/{data}.csv'
            print('load data from', filename)
            df_list.append(pd.read_csv(filename, index_col=0))
        self.df = pd.concat(df_list, axis=0).reset_index(drop=True)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        img = Image.open(os.path.join(self.parent_data_path, row.imgpath))
        if self.transform is not None:
            img = self.transform(img)
        # np.float64 replaces np.float_, which NumPy 2.0 removed (same dtype).
        label = torch.from_numpy(row[self.class_names].values.astype(np.float64))
        return img, label

    def _pad_img(self, img, min_size=224, fill_color=0):
        '''Pad ``img`` onto a centred square grayscale ('L') canvas.

        The side length is ``max(min_size, width, height)``. Not called by
        ``__getitem__``.
        '''
        x, y = img.size
        size = max(min_size, x, y)
        new_im = Image.new('L', (size, size), fill_color)
        new_im.paste(img, (int((size - x) / 2), int((size - y) / 2)))
        return new_im

    def __len__(self):
        return len(self.df)
class RadImageNet(torch.utils.data.Dataset):
    """RadImageNet samples listed in ``radimagenet_test_set_formatted.csv``.

    CSV column 0 holds a Python-literal sequence whose first element is the
    image path (relative to ``parent_path``); column 1 holds the label.
    """

    def __init__(self, parent_path, transform=None):
        """
        Args:
            parent_path (string): Directory containing the CSV and the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.data = pd.read_csv(os.path.join(parent_path, "radimagenet_test_set_formatted.csv"))
        self.transform = transform
        self.parent_data_path = parent_path

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = ast.literal_eval(self.data.iloc[idx, 0])[0]
        img = Image.open(os.path.join(self.parent_data_path, img_name))
        label = self.data.iloc[idx, 1]
        # Return the raw image when no transform is set (previously an
        # UnboundLocalError).
        image = self.transform(img) if self.transform is not None else img
        return image, label
class CT_dataset(torch.utils.data.Dataset):
    """Test-split samples from ``<parent_path>/<foldername>/annotations.csv``.

    Only rows whose ``split`` column equals ``'test'`` are kept. Column 0 is
    the image path relative to the folder; column 2 is the label.
    """

    def __init__(self, parent_path, transform=None, foldername=None):
        """
        Args:
            parent_path (string): Root directory holding ``foldername``.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            foldername (string): Sub-folder containing ``annotations.csv``
                and the images.
        """
        self.parent_data_path = os.path.join(parent_path, foldername)
        all_data = pd.read_csv(os.path.join(self.parent_data_path, "annotations.csv"))
        # Keep only the test samples.
        self.data = all_data[all_data['split'] == 'test']
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = self.data.iloc[idx, 0]
        img = Image.open(os.path.join(self.parent_data_path, img_name))
        label = self.data.iloc[idx, 2]
        # Return the raw image when no transform is set (previously an
        # UnboundLocalError).
        image = self.transform(img) if self.transform is not None else img
        return image, label
class LC25000(torch.utils.data.Dataset):
    """LC25000 images stored as one sub-folder per class under ``parent_path``.

    ``split == "lung"`` selects the three lung classes; any other value
    (including the default ``None``) selects the two colon classes. A sample's
    label is its class' position in that class list.
    """

    def __init__(self, parent_path, transform=None, split=None):
        """
        Args:
            parent_path (string): Directory containing the class folders.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            split (string): ``"lung"`` or anything else for colon.
        """
        if split == "lung":
            classes = ['lung_aca', 'lung_n', 'lung_scc']
        else:
            classes = ['colon_aca', 'colon_n']
        self.split = split
        self.class_name_folder = []
        self.images = []
        self.labels = []
        for label, class_folder in enumerate(classes):
            # Sorted so sample order does not depend on os.listdir ordering.
            file_names = sorted(os.listdir(os.path.join(parent_path, class_folder)))
            self.images.extend(file_names)
            self.labels.extend([label] * len(file_names))
            self.class_name_folder.extend([class_folder] * len(file_names))
        self.transform = transform
        self.parent_data_path = parent_path
        assert len(self.images) == len(self.labels) == len(self.class_name_folder)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_path = os.path.join(self.parent_data_path, self.class_name_folder[idx], self.images[idx])
        img = Image.open(img_path)
        label = self.labels[idx]
        # Return the raw image when no transform is set (previously an
        # UnboundLocalError).
        image = self.transform(img) if self.transform is not None else img
        return image, label
class PCAM(torch.utils.data.Dataset):
    """PCam images in ``<parent_path>/<foldername>``, label encoded in file name.

    The label is the integer between the first ``_`` and the following ``.``
    of each file name (e.g. ``xyz_1.png`` -> 1).
    """

    def __init__(self, parent_path, transform=None, foldername=None):
        """
        Args:
            parent_path (string): Root directory holding ``foldername``.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            foldername (string): Folder containing the image files.
        """
        self.parent_data_path = os.path.join(parent_path, foldername)
        # Sorted so sample order does not depend on os.listdir ordering.
        self.data = sorted(os.listdir(self.parent_data_path))
        self.labels = [int(file_name.split("_")[1].split(".")[0]) for file_name in self.data]
        self.transform = transform
        assert len(self.labels) == len(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img = Image.open(os.path.join(self.parent_data_path, self.data[idx]))
        label = self.labels[idx]
        # Return the raw image when no transform is set (previously an
        # UnboundLocalError).
        image = self.transform(img) if self.transform is not None else img
        return image, label
class NCK_CRC(torch.utils.data.Dataset):
    """NCT-CRC-HE tiles in ``<parent_path>/<foldername>/<CLASS>/<image>``.

    Every sub-folder name must be a key of the class-to-label table below.
    """

    def __init__(self, parent_path, transform=None, foldername=None):
        """
        Args:
            parent_path (string): Root directory holding ``foldername``.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            foldername (string): Folder containing one sub-folder per class.
        """
        NCK_CRC_converter = {"ADI": 0,
                             "DEB": 1,
                             "LYM": 2,
                             "MUC": 3,
                             "MUS": 4,
                             "NORM": 5,
                             "STR": 6,
                             "TUM": 7,
                             }
        dataset_root = os.path.join(parent_path, foldername)
        all_data = []
        all_class_names = []
        all_labels = []
        # Sorted so sample order does not depend on os.listdir ordering.
        for class_folder in sorted(os.listdir(dataset_root)):
            images_inside_folder = sorted(os.listdir(os.path.join(dataset_root, class_folder)))
            all_data.extend(images_inside_folder)
            all_labels.extend([NCK_CRC_converter[class_folder]] * len(images_inside_folder))
            all_class_names.extend([class_folder] * len(images_inside_folder))
        self.data = all_data
        self.labels = all_labels
        self.prefix_name = all_class_names
        assert len(self.data) == len(self.labels) == len(self.prefix_name)
        self.transform = transform
        self.parent_data_path = dataset_root

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        label = self.labels[idx]
        img = Image.open(os.path.join(self.parent_data_path, self.prefix_name[idx], self.data[idx]))
        # Return the raw image when no transform is set (previously an
        # UnboundLocalError).
        image = self.transform(img) if self.transform is not None else img
        return image, label
class BACH(torch.utils.data.Dataset):
    """BACH microscopy images listed in ``microscopy_ground_truth.csv``.

    CSV column 0 is the image file name and column 1 the class name; images
    live in ``<parent_path>/<foldername>/<class name>/<file name>``.
    """

    def __init__(self, parent_path, transform=None, foldername=None):
        """
        Args:
            parent_path (string): Root directory holding ``foldername``.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            foldername (string): Folder containing the CSV and class folders.
        """
        self.parent_data_path = os.path.join(parent_path, foldername)
        self.data = pd.read_csv(os.path.join(self.parent_data_path, "microscopy_ground_truth.csv"))
        self.transform = transform
        # Despite its name, maps class text -> integer label.
        self.label_to_text_mapping = {'Normal': 3, 'Invasive': 2, 'InSitu': 1, "Benign": 0}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = self.data.iloc[idx, 0]
        label_text = self.data.iloc[idx, 1]
        img = Image.open(os.path.join(self.parent_data_path, label_text, img_name))
        label = self.label_to_text_mapping[label_text]
        # Return the raw image when no transform is set (previously an
        # UnboundLocalError).
        image = self.transform(img) if self.transform is not None else img
        return image, label
class Retina_datasets(torch.utils.data.Dataset):
    """Retinal fundus datasets listed in ``./local_data/<data>``.

    CSV column 1 is the image path (relative to ``parent_path``) and column 3
    a Python-literal list whose first element is the class text.
    """

    _LABEL_MAPPINGS = {
        '25_REFUGE.csv': {'no glaucoma': 0, 'glaucoma': 1},
        '13_FIVES.csv': {"age related macular degeneration": 0,
                         "diabetic retinopathy": 1,
                         "glaucoma": 2,
                         "normal": 3},
        '08_ODIR200x3.csv': {"normal": 0,
                             "pathologic myopia": 1,
                             "cataract": 2},
    }

    def __init__(self, parent_path, transform=None, data=None):
        """
        Args:
            parent_path (string): Directory the image paths are relative to.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            data (string): CSV file name under ``./local_data``; must be one of
                the keys of ``_LABEL_MAPPINGS``.

        Raises:
            ValueError: if ``data`` has no known label mapping (previously this
                surfaced later as an AttributeError in ``__getitem__``).
        """
        filename = f'./local_data/{data}'
        self.data = pd.read_csv(filename)
        self.transform = transform
        self.parent_data_path = parent_path
        if data not in self._LABEL_MAPPINGS:
            raise ValueError(f"No label mapping defined for retina dataset {data!r}")
        self.label_to_text_mapping = dict(self._LABEL_MAPPINGS[data])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = self.data.iloc[idx, 1]
        # Column 3 stores a stringified list; its first entry is the class text.
        label_text = ast.literal_eval(self.data.iloc[idx, 3])[0]
        img = Image.open(os.path.join(self.parent_data_path, img_name))
        label = self.label_to_text_mapping[label_text]
        # Return the raw image when no transform is set (previously an
        # UnboundLocalError).
        image = self.transform(img) if self.transform is not None else img
        return image, label
class Skin_datasets(torch.utils.data.Dataset):
    """Skin histopathology tiles, test split only.

    ``split_type == 'cancer'`` reads ``data/tiles-v2.csv`` (16 classes); any
    other value reads ``data/SkinTumorSubset.csv`` (4 tumour classes). Rows
    are kept when their ``set`` column equals ``'Test'``. CSV column 1 is the
    image path relative to ``<parent_path>/<foldername>`` and column 2 the
    class text.
    """

    def __init__(self, parent_path, transform=None, foldername=None, split_type='cancer'):
        """
        Args:
            parent_path (string): Root directory holding ``foldername``.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            foldername (string): Folder containing ``data/`` and the images.
            split_type (string): ``'cancer'`` or anything else for tumour.
        """
        if split_type == 'cancer':
            csv_name = "tiles-v2.csv"
            self.label_to_text_mapping = {
                "nontumor_skin_necrosis_necrosis": 0,
                "nontumor_skin_muscle_skeletal": 1,
                "nontumor_skin_sweatglands_sweatglands": 2,
                "nontumor_skin_vessel_vessel": 3,
                "nontumor_skin_elastosis_elastosis": 4,
                "nontumor_skin_chondraltissue_chondraltissue": 5,
                "nontumor_skin_hairfollicle_hairfollicle": 6,
                "nontumor_skin_epidermis_epidermis": 7,
                "nontumor_skin_nerves_nerves": 8,
                "nontumor_skin_subcutis_subcutis": 9,
                "nontumor_skin_dermis_dermis": 10,
                "nontumor_skin_sebaceousglands_sebaceousglands": 11,
                "tumor_skin_epithelial_sqcc": 12,
                "tumor_skin_melanoma_melanoma": 13,
                "tumor_skin_epithelial_bcc": 14,
                "tumor_skin_naevus_naevus": 15,
            }
        else:
            csv_name = "SkinTumorSubset.csv"
            self.label_to_text_mapping = {
                "tumor_skin_epithelial_sqcc": 0,
                "tumor_skin_melanoma_melanoma": 1,
                "tumor_skin_epithelial_bcc": 2,
                "tumor_skin_naevus_naevus": 3,
            }
        self.parent_data_path = os.path.join(parent_path, foldername)
        all_data = pd.read_csv(os.path.join(self.parent_data_path, "data", csv_name))
        # Keep only the test samples.
        self.data = all_data[all_data['set'] == 'Test']
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = self.data.iloc[idx, 1]
        label_text = self.data.iloc[idx, 2]
        img = Image.open(os.path.join(self.parent_data_path, img_name))
        label = self.label_to_text_mapping[label_text]
        # Return the raw image when no transform is set (previously an
        # UnboundLocalError).
        image = self.transform(img) if self.transform is not None else img
        return image, label
class Osteo(torch.utils.data.Dataset):
    """Osteosarcoma tiles listed in ``annotations_final.csv``.

    CSV column 0 is the image path relative to ``<parent_path>/<foldername>``
    and column 1 the class text.
    """

    def __init__(self, parent_path, transform=None, foldername=None):
        """
        Args:
            parent_path (string): Root directory holding ``foldername``.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            foldername (string): Folder containing the CSV and the images.
        """
        self.parent_data_path = os.path.join(parent_path, foldername)
        self.data = pd.read_csv(os.path.join(self.parent_data_path, "annotations_final.csv"))
        self.transform = transform
        # Despite its name, maps class text -> integer label.
        self.label_to_text_mapping = {'Viable': 2, 'Non-Tumor': 0, "Non-Viable-Tumor": 1}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = self.data.iloc[idx, 0]
        label_text = self.data.iloc[idx, 1]
        img = Image.open(os.path.join(self.parent_data_path, img_name))
        label = self.label_to_text_mapping[label_text]
        # Return the raw image when no transform is set (previously an
        # UnboundLocalError).
        image = self.transform(img) if self.transform is not None else img
        return image, label
class RSNA_Pneumonia(torch.utils.data.Dataset):
    """RSNA pneumonia samples listed in ``RSNA_pneumonia_balanced_testfile.csv``.

    CSV column 1 is the image path relative to ``parent_path`` and column 2
    the label.
    """

    def __init__(self, parent_path, transform=None):
        """
        Args:
            parent_path (string): Directory containing the CSV and the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.data = pd.read_csv(os.path.join(parent_path, "RSNA_pneumonia_balanced_testfile.csv"))
        self.transform = transform
        self.parent_data_path = parent_path

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = self.data.iloc[idx, 1]
        img = Image.open(os.path.join(self.parent_data_path, img_name))
        label = self.data.iloc[idx, 2]
        # Return the raw image when no transform is set (previously an
        # UnboundLocalError).
        image = self.transform(img) if self.transform is not None else img
        return image, label
class thyroid_us_and_breast(torch.utils.data.Dataset):
    """Thyroid/breast ultrasound samples listed in ``csv_file_name``.

    CSV column 0 is the image path relative to ``parent_path`` and column 1
    the text label (``'malignant'`` -> 1, ``'benign'`` -> 0).
    """

    def __init__(self, parent_path, transform, csv_file_name):
        """
        Args:
            parent_path (string): Directory containing the CSV and the images.
            transform (callable or None): Transform applied to each image;
                ``None`` returns the raw image.
            csv_file_name (string): CSV file name inside ``parent_path``.
        """
        self.data = pd.read_csv(os.path.join(parent_path, csv_file_name))
        self.transform = transform
        self.parent_data_path = parent_path
        self.mapping = {'malignant': 1, 'benign': 0}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = self.data.iloc[idx, 0]
        img = Image.open(os.path.join(self.parent_data_path, img_name))
        label = self.mapping[self.data.iloc[idx, 1]]
        # Return the raw image when no transform is set (previously an
        # UnboundLocalError).
        image = self.transform(img) if self.transform is not None else img
        return image, label
class meniscal_mri(torch.utils.data.Dataset):
    """Meniscal-tear MRI samples listed in ``csv_file_name``.

    CSV column 0 is the image path relative to ``parent_path`` and column 1
    the text label (``'p'`` -> 1, ``'n'`` -> 0).
    """

    def __init__(self, parent_path, transform, csv_file_name):
        """
        Args:
            parent_path (string): Directory containing the CSV and the images.
            transform (callable or None): Transform applied to each image;
                ``None`` returns the raw image.
            csv_file_name (string): CSV file name inside ``parent_path``.
        """
        self.data = pd.read_csv(os.path.join(parent_path, csv_file_name))
        self.transform = transform
        self.parent_data_path = parent_path
        self.mapping = {'p': 1, 'n': 0}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = self.data.iloc[idx, 0]
        img = Image.open(os.path.join(self.parent_data_path, img_name))
        label = self.mapping[self.data.iloc[idx, 1]]
        # Return the raw image when no transform is set (previously an
        # UnboundLocalError).
        image = self.transform(img) if self.transform is not None else img
        return image, label
class acl_mri(torch.utils.data.Dataset):
    """ACL MRI samples listed in ``csv_file_name``.

    CSV column 0 is the image path relative to ``parent_path`` and column 1
    the text label (``'yes'`` -> 1, ``'no'`` -> 0).
    """

    def __init__(self, parent_path, transform, csv_file_name):
        """
        Args:
            parent_path (string): Directory containing the CSV and the images.
            transform (callable or None): Transform applied to each image;
                ``None`` returns the raw image.
            csv_file_name (string): CSV file name inside ``parent_path``.
        """
        self.data = pd.read_csv(os.path.join(parent_path, csv_file_name))
        self.transform = transform
        self.parent_data_path = parent_path
        self.mapping = {'yes': 1, 'no': 0}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = self.data.iloc[idx, 0]
        img = Image.open(os.path.join(self.parent_data_path, img_name))
        label = self.mapping[self.data.iloc[idx, 1]]
        # Return the raw image when no transform is set (previously an
        # UnboundLocalError).
        image = self.transform(img) if self.transform is not None else img
        return image, label