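"""Data loading and preprocessing utilities for multi-task cell classification.

Loads datasets from disk, builds (or reloads) per-task label mappings, converts
records into tensors with per-task labels, and wraps them in DataLoaders.
"""
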
import os

from .collators import DataCollatorForMultitaskCellClassification
from .imports import *  # expected to provide torch, pickle, load_from_disk, and DataLoader used below


def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
    try:
        dataset = load_from_disk(dataset_path)

        task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
        task_to_column = dict(zip(task_names, config["task_columns"]))
        config["task_names"] = task_names

        if not is_test:
            available_columns = set(dataset.column_names)
            for column in task_to_column.values():
                if column not in available_columns:
                    raise KeyError(
                        f"Column {column} not found in the dataset. Available columns: {list(available_columns)}"
                    )

        label_mappings = {}
        task_label_mappings = {}
        cell_id_mapping = {}
        num_labels_list = []

        # Load or create task label mappings
        if not is_test:
            for task, column in task_to_column.items():
                unique_values = sorted(set(dataset[column]))  # Ensure consistency
                label_mappings[column] = {
                    label: idx for idx, label in enumerate(unique_values)
                }
                task_label_mappings[task] = label_mappings[column]
                num_labels_list.append(len(unique_values))

            # Print the mappings for each task with dataset type prefix
            for task, mapping in task_label_mappings.items():
                print(
                    f"{dataset_type.capitalize()} mapping for {task}: {mapping}"
                )  # sanity check, for train/validation splits

            # Save the task label mappings as a pickle file
            with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f:
                pickle.dump(task_label_mappings, f)
        else:
            # Load task label mappings from pickle file for test data
            with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
                task_label_mappings = pickle.load(f)

            # Infer num_labels_list from task_label_mappings
            for task, mapping in task_label_mappings.items():
                num_labels_list.append(len(mapping))

        # Store unique cell IDs in a separate dictionary
        for idx, record in enumerate(dataset):
            cell_id = record.get("unique_cell_id", idx)
            cell_id_mapping[idx] = cell_id

        # Transform records to the desired format
        transformed_dataset = []
        for idx, record in enumerate(dataset):
            transformed_record = {}
            transformed_record["input_ids"] = torch.tensor(
                record["input_ids"], dtype=torch.long
            )

            # Use index-based cell ID for internal tracking
            transformed_record["cell_id"] = idx

            if not is_test:
                # Prepare labels
                label_dict = {}
                for task, column in task_to_column.items():
                    label_value = record[column]
                    label_index = task_label_mappings[task][label_value]
                    label_dict[task] = label_index
                transformed_record["label"] = label_dict
            else:
                # Create dummy labels for test data
                label_dict = {task: -1 for task in config["task_names"]}
                transformed_record["label"] = label_dict

            transformed_dataset.append(transformed_record)

        return transformed_dataset, cell_id_mapping, num_labels_list

    except KeyError as e:
        print(f"Missing configuration or dataset key: {e}")
    except Exception as e:
        print(f"An error occurred while loading or preprocessing data: {e}")
    return None, None, None


def preload_and_process_data(config):
    # Load and preprocess data once
    train_dataset, train_cell_id_mapping, num_labels_list = load_and_preprocess_data(
        config["train_path"], config, dataset_type="train"
    )
    val_dataset, val_cell_id_mapping, _ = load_and_preprocess_data(
        config["val_path"], config, dataset_type="validation"
    )
    return (
        train_dataset,
        train_cell_id_mapping,
        val_dataset,
        val_cell_id_mapping,
        num_labels_list,
    )


def get_data_loader(preprocessed_dataset, batch_size):
    nproc = os.cpu_count()  # use all CPU cores for I/O-bound data-loading workers

    data_collator = DataCollatorForMultitaskCellClassification()

    loader = DataLoader(
        preprocessed_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=data_collator,
        num_workers=nproc,
        pin_memory=True,
    )
    return loader


def preload_data(config):
    # Preprocess the data and build DataLoaders before the Optuna trials start
    train_dataset, _, val_dataset, _, _ = preload_and_process_data(config)
    train_loader = get_data_loader(train_dataset, config["batch_size"])
    val_loader = get_data_loader(val_dataset, config["batch_size"])
    return train_loader, val_loader


def load_and_preprocess_test_data(config):
    """
    Load and preprocess test data, treating it as unlabeled.
    """
    return load_and_preprocess_data(config["test_path"], config, is_test=True)


def prepare_test_loader(config):
    """
    Prepare DataLoader for the test dataset.
    """
    test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(
        config
    )
    test_loader = get_data_loader(test_dataset, config["batch_size"])
    return test_loader, cell_id_mapping, num_labels_list
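

# The block below is a hypothetical usage sketch, not part of the module's API:
# the paths, task column names, and batch size are placeholders and would need
# to point at an actual tokenized dataset in the on-disk format expected by
# load_from_disk, with a config containing the keys referenced above.
if __name__ == "__main__":
    example_config = {
        "train_path": "/path/to/train_dataset",
        "val_path": "/path/to/val_dataset",
        "test_path": "/path/to/test_dataset",
        "task_columns": ["cell_type", "disease_state"],  # placeholder column names
        "results_dir": "/path/to/results",
        "batch_size": 16,
    }

    # Preprocess train/validation splits once, then build shuffled DataLoaders.
    (
        train_data,
        train_ids,
        val_data,
        val_ids,
        num_labels,
    ) = preload_and_process_data(example_config)
    train_loader = get_data_loader(train_data, example_config["batch_size"])
    val_loader = get_data_loader(val_data, example_config["batch_size"])

    # Test data reuses the saved label mappings and receives dummy labels (-1).
    test_loader, test_ids, _ = prepare_test_loader(example_config)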